diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
| commit | 20ef0d92694465ac86b550df139e8366a0a2b4fa (patch) | |
| tree | 3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/zed/internal | |
| parent | 44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff) | |
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/zed/internal')
28 files changed, 6240 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/zed/internal/client/client.go b/vendor/github.com/authzed/zed/internal/client/client.go new file mode 100644 index 0000000..17ddb11 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/client/client.go @@ -0,0 +1,303 @@ +package client + +import ( + "context" + "fmt" + "net" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector" + "github.com/jzelinskie/cobrautil/v2" + "github.com/mitchellh/go-homedir" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "golang.org/x/net/proxy" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/authzed-go/v1" + "github.com/authzed/grpcutil" + + zgrpcutil "github.com/authzed/zed/internal/grpcutil" + "github.com/authzed/zed/internal/storage" +) + +// Client defines an interface for making calls to SpiceDB API. +type Client interface { + v1.SchemaServiceClient + v1.PermissionsServiceClient + v1.WatchServiceClient + v1.ExperimentalServiceClient +} + +const ( + defaultRetryExponentialBackoff = 100 * time.Millisecond + defaultMaxRetryAttemptDuration = 2 * time.Second + defaultRetryJitterFraction = 0.5 + bulkImportRoute = "/authzed.api.v1.ExperimentalService/BulkImportRelationships" + importBulkRoute = "/authzed.api.v1.PermissionsService/ImportBulkRelationships" +) + +// NewClient defines an (overridable) means of creating a new client. +var ( + NewClient = newClientForCurrentContext + NewClientForContext = newClientForContext +) + +func newClientForCurrentContext(cmd *cobra.Command) (Client, error) { + configStore, secretStore := DefaultStorage() + token, err := GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) + if err != nil { + return nil, err + } + + dialOpts, err := DialOptsFromFlags(cmd, token) + if err != nil { + return nil, err + } + + if cobrautil.MustGetString(cmd, "proxy") != "" { + token.Endpoint = withPassthroughTarget(token.Endpoint) + } + + client, err := authzed.NewClientWithExperimentalAPIs(token.Endpoint, dialOpts...) + if err != nil { + return nil, err + } + + return client, err +} + +func newClientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) { + currentToken, err := storage.GetTokenIfExists(contextName, secretStore) + if err != nil { + return nil, err + } + + token, err := GetTokenWithCLIOverride(cmd, currentToken) + if err != nil { + return nil, err + } + + dialOpts, err := DialOptsFromFlags(cmd, token) + if err != nil { + return nil, err + } + + if cobrautil.MustGetString(cmd, "proxy") != "" { + token.Endpoint = withPassthroughTarget(token.Endpoint) + } + + return authzed.NewClient(token.Endpoint, dialOpts...) +} + +// GetCurrentTokenWithCLIOverride returns the current token, but overridden by any parameter specified via CLI args +func GetCurrentTokenWithCLIOverride(cmd *cobra.Command, configStore storage.ConfigStore, secretStore storage.SecretStore) (storage.Token, error) { + // Handle the no-config case separately + configExists, err := configStore.Exists() + if err != nil { + return storage.Token{}, err + } + if !configExists { + return GetTokenWithCLIOverride(cmd, storage.Token{}) + } + token, err := storage.CurrentToken( + configStore, + secretStore, + ) + if err != nil { + return storage.Token{}, err + } + + return GetTokenWithCLIOverride(cmd, token) +} + +// GetTokenWithCLIOverride returns the provided token, but overridden by any parameter specified explicitly via command +// flags +func GetTokenWithCLIOverride(cmd *cobra.Command, token storage.Token) (storage.Token, error) { + overrideToken, err := tokenFromCli(cmd) + if err != nil { + return storage.Token{}, err + } + + result, err := storage.TokenWithOverride( + overrideToken, + token, + ) + if err != nil { + return storage.Token{}, err + } + + log.Trace().Bool("context-override-via-cli", overrideToken.AnyValue()).Interface("context", result).Send() + return result, nil +} + +func tokenFromCli(cmd *cobra.Command) (storage.Token, error) { + certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path") + var certBytes []byte + var err error + if certPath != "" { + certBytes, err = os.ReadFile(certPath) + if err != nil { + return storage.Token{}, fmt.Errorf("failed to read ceritficate: %w", err) + } + } + + explicitInsecure := cmd.Flags().Changed("insecure") + var notSecure *bool + if explicitInsecure { + i := cobrautil.MustGetBool(cmd, "insecure") + notSecure = &i + } + + explicitNoVerifyCA := cmd.Flags().Changed("no-verify-ca") + var notVerifyCA *bool + if explicitNoVerifyCA { + nvc := cobrautil.MustGetBool(cmd, "no-verify-ca") + notVerifyCA = &nvc + } + overrideToken := storage.Token{ + APIToken: cobrautil.MustGetString(cmd, "token"), + Endpoint: cobrautil.MustGetString(cmd, "endpoint"), + Insecure: notSecure, + NoVerifyCA: notVerifyCA, + CACert: certBytes, + } + return overrideToken, nil +} + +// DefaultStorage returns the default configured config store and secret store. +var DefaultStorage = defaultStorage + +func defaultStorage() (storage.ConfigStore, storage.SecretStore) { + var home string + if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" { + home = filepath.Join(xdg, "zed") + } else { + hmdir, _ := homedir.Dir() + home = filepath.Join(hmdir, ".zed") + } + return &storage.JSONConfigStore{ConfigPath: home}, + &storage.KeychainSecretStore{ConfigPath: home} +} + +func certOption(token storage.Token) (opt grpc.DialOption, err error) { + verification := grpcutil.VerifyCA + if token.HasNoVerifyCA() { + verification = grpcutil.SkipVerifyCA + } + + if certBytes, ok := token.Certificate(); ok { + return grpcutil.WithCustomCertBytes(verification, certBytes) + } + + return grpcutil.WithSystemCerts(verification) +} + +func isNoneOf(routes ...string) func(_ context.Context, c interceptors.CallMeta) bool { + return func(_ context.Context, c interceptors.CallMeta) bool { + for _, route := range routes { + if route == c.FullMethod() { + return false + } + } + return true + } +} + +// DialOptsFromFlags returns the dial options from the CLI-specified flags. +func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOption, error) { + maxRetries := cobrautil.MustGetUint(cmd, "max-retries") + retryOpts := []retry.CallOption{ + retry.WithBackoff(retry.BackoffExponentialWithJitterBounded(defaultRetryExponentialBackoff, + defaultRetryJitterFraction, defaultMaxRetryAttemptDuration)), + retry.WithCodes(codes.ResourceExhausted, codes.Unavailable, codes.Aborted, codes.Unknown, codes.Internal), + retry.WithMax(maxRetries), + retry.WithOnRetryCallback(func(_ context.Context, attempt uint, err error) { + log.Error().Err(err).Uint("attempt", attempt).Msg("retrying gRPC call") + }), + } + unaryInterceptors := []grpc.UnaryClientInterceptor{ + zgrpcutil.LogDispatchTrailers, + retry.UnaryClientInterceptor(retryOpts...), + } + + streamInterceptors := []grpc.StreamClientInterceptor{ + zgrpcutil.StreamLogDispatchTrailers, + selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(bulkImportRoute, importBulkRoute))), + } + + if !cobrautil.MustGetBool(cmd, "skip-version-check") { + unaryInterceptors = append(unaryInterceptors, zgrpcutil.CheckServerVersion) + } + + opts := []grpc.DialOption{ + grpc.WithChainUnaryInterceptor(unaryInterceptors...), + grpc.WithChainStreamInterceptor(streamInterceptors...), + } + + proxyAddr := cobrautil.MustGetString(cmd, "proxy") + + if proxyAddr != "" { + addr, err := url.Parse(proxyAddr) + if err != nil { + return nil, fmt.Errorf("failed to parse socks5 proxy addr: %w", err) + } + + dialer, err := proxy.SOCKS5("tcp", addr.Host, nil, proxy.Direct) + if err != nil { + return nil, fmt.Errorf("failed to create socks5 proxy dialer: %w", err) + } + + opts = append(opts, grpc.WithContextDialer(func(_ context.Context, addr string) (net.Conn, error) { + return dialer.Dial("tcp", addr) + })) + } + + if token.IsInsecure() { + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + opts = append(opts, grpcutil.WithInsecureBearerToken(token.APIToken)) + } else { + opts = append(opts, grpcutil.WithBearerToken(token.APIToken)) + certOpt, err := certOption(token) + if err != nil { + return nil, fmt.Errorf("failed to configure TLS cert: %w", err) + } + opts = append(opts, certOpt) + } + + hostnameOverride := cobrautil.MustGetString(cmd, "hostname-override") + if hostnameOverride != "" { + opts = append(opts, grpc.WithAuthority(hostnameOverride)) + } + + maxMessageSize := cobrautil.MustGetInt(cmd, "max-message-size") + if maxMessageSize != 0 { + opts = append(opts, grpc.WithDefaultCallOptions( + // The default max client message size is 4mb. + // It's conceivable that a sufficiently complex + // schema will easily surpass this, so we set the + // limit higher here. + grpc.MaxCallRecvMsgSize(maxMessageSize), + grpc.MaxCallSendMsgSize(maxMessageSize), + )) + } + + return opts, nil +} + +func withPassthroughTarget(endpoint string) string { + // If it already has a scheme, return as-is + if strings.Contains(endpoint, "://") { + return endpoint + } + return "passthrough:///" + endpoint +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/backup.go b/vendor/github.com/authzed/zed/internal/cmd/backup.go new file mode 100644 index 0000000..57fa83c --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/backup.go @@ -0,0 +1,868 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/mattn/go-isatty" + "github.com/rodaine/table" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "golang.org/x/exp/constraints" + "golang.org/x/exp/maps" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + schemapkg "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/tuple" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/console" + "github.com/authzed/zed/pkg/backupformat" +) + +const ( + returnIfExists = true + doNotReturnIfExists = false +) + +// cobraRunEFunc is the signature of a cobra.Command.RunE function. +type cobraRunEFunc = func(cmd *cobra.Command, args []string) (err error) + +// withErrorHandling is a wrapper that centralizes error handling, instead of having to scatter it around the command logic. +func withErrorHandling(f cobraRunEFunc) cobraRunEFunc { + return func(cmd *cobra.Command, args []string) (err error) { + return addSizeErrInfo(f(cmd, args)) + } +} + +var ( + backupCmd = &cobra.Command{ + Use: "backup <filename>", + Short: "Create, restore, and inspect permissions system backups", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + // Create used to be on the root, so add it here for back-compat. + RunE: withErrorHandling(backupCreateCmdFunc), + } + + backupCreateCmd = &cobra.Command{ + Use: "create <filename>", + Short: "Backup a permission system to a file", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + RunE: withErrorHandling(backupCreateCmdFunc), + } + + backupRestoreCmd = &cobra.Command{ + Use: "restore <filename>", + Short: "Restore a permission system from a file", + Args: commands.ValidationWrapper(commands.StdinOrExactArgs(1)), + RunE: backupRestoreCmdFunc, + } + + backupParseSchemaCmd = &cobra.Command{ + Use: "parse-schema <filename>", + Short: "Extract the schema from a backup file", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupParseSchemaCmdFunc(cmd, os.Stdout, args) + }, + } + + backupParseRevisionCmd = &cobra.Command{ + Use: "parse-revision <filename>", + Short: "Extract the revision from a backup file", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupParseRevisionCmdFunc(cmd, os.Stdout, args) + }, + } + + backupParseRelsCmd = &cobra.Command{ + Use: "parse-relationships <filename>", + Short: "Extract the relationships from a backup file", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupParseRelsCmdFunc(cmd, os.Stdout, args) + }, + } + + backupRedactCmd = &cobra.Command{ + Use: "redact <filename>", + Short: "Redact a backup file to remove sensitive information", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupRedactCmdFunc(cmd, args) + }, + } +) + +func registerBackupCmd(rootCmd *cobra.Command) { + rootCmd.AddCommand(backupCmd) + registerBackupCreateFlags(backupCmd) + + backupCmd.AddCommand(backupCreateCmd) + registerBackupCreateFlags(backupCreateCmd) + + backupCmd.AddCommand(backupRestoreCmd) + registerBackupRestoreFlags(backupRestoreCmd) + + backupCmd.AddCommand(backupRedactCmd) + backupRedactCmd.Flags().Bool("redact-definitions", true, "redact definitions") + backupRedactCmd.Flags().Bool("redact-relations", true, "redact relations") + backupRedactCmd.Flags().Bool("redact-object-ids", true, "redact object IDs") + backupRedactCmd.Flags().Bool("print-redacted-object-ids", false, "prints the redacted object IDs") + + // Restore used to be on the root, so add it there too, but hidden. + restoreCmd := &cobra.Command{ + Use: "restore <filename>", + Short: "Restore a permission system from a backup file", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + RunE: backupRestoreCmdFunc, + Hidden: true, + } + rootCmd.AddCommand(restoreCmd) + registerBackupRestoreFlags(restoreCmd) + + backupCmd.AddCommand(backupParseSchemaCmd) + backupParseSchemaCmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix") + backupParseSchemaCmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") + + backupCmd.AddCommand(backupParseRevisionCmd) + backupCmd.AddCommand(backupParseRelsCmd) + backupParseRelsCmd.Flags().String("prefix-filter", "", "Include only relationships with a given prefix") +} + +func registerBackupRestoreFlags(cmd *cobra.Command) { + cmd.Flags().Uint("batch-size", 1_000, "restore relationship write batch size") + cmd.Flags().Uint("batches-per-transaction", 10, "number of batches per transaction") + cmd.Flags().String("conflict-strategy", "fail", "strategy used when a conflicting relationship is found. Possible values: fail, skip, touch") + cmd.Flags().Bool("disable-retries", false, "retries when an errors is determined to be retryable (e.g. serialization errors)") + cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix") + cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") + cmd.Flags().Duration("request-timeout", 30*time.Second, "timeout for each request performed during restore") +} + +func registerBackupCreateFlags(cmd *cobra.Command) { + cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix") + cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") + cmd.Flags().Uint32("page-limit", 0, "defines the number of relationships to be read by requested page during backup") +} + +func createBackupFile(filename string, returnIfExists bool) (*os.File, bool, error) { + if filename == "-" { + log.Trace().Str("filename", "- (stdout)").Send() + return os.Stdout, false, nil + } + + log.Trace().Str("filename", filename).Send() + + if _, err := os.Stat(filename); err == nil { + if !returnIfExists { + return nil, false, fmt.Errorf("backup file already exists: %s", filename) + } + + f, err := os.OpenFile(filename, os.O_RDWR|os.O_APPEND, 0o644) + if err != nil { + return nil, false, fmt.Errorf("unable to open existing backup file: %w", err) + } + + return f, true, nil + } + + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o644) + if err != nil { + return nil, false, fmt.Errorf("unable to create backup file: %w", err) + } + + return f, false, nil +} + +var ( + missingAllowedTypes = regexp.MustCompile(`(\s*)(relation)(.+)(/\* missing allowed types \*/)(.*)`) + shortRelations = regexp.MustCompile(`(\s*)relation [a-z][a-z0-9_]:(.+)`) +) + +func partialPrefixMatch(name, prefix string) bool { + return strings.HasPrefix(name, prefix+"/") +} + +func filterSchemaDefs(schema, prefix string) (filteredSchema string, err error) { + if schema == "" || prefix == "" { + return schema, nil + } + + compiledSchema, err := compiler.Compile( + compiler.InputSchema{Source: "schema", SchemaString: schema}, + compiler.AllowUnprefixedObjectType(), + compiler.SkipValidation(), + ) + if err != nil { + return "", fmt.Errorf("error reading schema: %w", err) + } + + var prefixedDefs []compiler.SchemaDefinition + for _, def := range compiledSchema.ObjectDefinitions { + if partialPrefixMatch(def.Name, prefix) { + prefixedDefs = append(prefixedDefs, def) + } + } + for _, def := range compiledSchema.CaveatDefinitions { + if partialPrefixMatch(def.Name, prefix) { + prefixedDefs = append(prefixedDefs, def) + } + } + + if len(prefixedDefs) == 0 { + return "", errors.New("filtered all definitions from schema") + } + + filteredSchema, _, err = generator.GenerateSchema(prefixedDefs) + if err != nil { + return "", fmt.Errorf("error generating filtered schema: %w", err) + } + + // Validate that the type system for the generated schema is comprehensive. + compiledFilteredSchema, err := compiler.Compile( + compiler.InputSchema{Source: "generated-schema", SchemaString: filteredSchema}, + compiler.AllowUnprefixedObjectType(), + ) + if err != nil { + return "", fmt.Errorf("generated invalid schema: %w", err) + } + + for _, rawDef := range compiledFilteredSchema.ObjectDefinitions { + ts := schemapkg.NewTypeSystem(schemapkg.ResolverForCompiledSchema(*compiledFilteredSchema)) + def, err := schemapkg.NewDefinition(ts, rawDef) + if err != nil { + return "", fmt.Errorf("generated invalid schema: %w", err) + } + if _, err := def.Validate(context.Background()); err != nil { + return "", fmt.Errorf("generated invalid schema: %w", err) + } + } + + return +} + +func hasRelPrefix(rel *v1.Relationship, prefix string) bool { + // Skip any relationships without the prefix on both sides. + return strings.HasPrefix(rel.Resource.ObjectType, prefix) && + strings.HasPrefix(rel.Subject.Object.ObjectType, prefix) +} + +func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + pageLimit := cobrautil.MustGetUint32(cmd, "page-limit") + + backupFileName, err := computeBackupFileName(cmd, args) + if err != nil { + return err + } + + backupFile, backupExists, err := createBackupFile(backupFileName, returnIfExists) + if err != nil { + return err + } + + defer func(e *error) { + *e = errors.Join(*e, backupFile.Sync()) + *e = errors.Join(*e, backupFile.Close()) + }(&err) + + // the goal of this file is to keep the bulk export cursor in case the process is terminated + // and we need to resume from where we left off. OCF does not support in-place record updates. + progressFile, cursor, err := openProgressFile(backupFileName, backupExists) + if err != nil { + return err + } + + var backupCompleted bool + defer func(e *error) { + *e = errors.Join(*e, progressFile.Sync()) + *e = errors.Join(*e, progressFile.Close()) + + if backupCompleted { + if err := os.Remove(progressFile.Name()); err != nil { + log.Warn(). + Str("progress-file", progressFile.Name()). + Msg("failed to remove progress file, consider removing it manually") + } + } + }(&err) + + c, err := client.NewClient(cmd) + if err != nil { + return fmt.Errorf("unable to initialize client: %w", err) + } + + var zedToken *v1.ZedToken + var encoder *backupformat.Encoder + if backupExists { + encoder, err = backupformat.NewEncoderForExisting(backupFile) + if err != nil { + return fmt.Errorf("error creating backup file encoder: %w", err) + } + } else { + encoder, zedToken, err = encoderForNewBackup(cmd, c, backupFile) + if err != nil { + return err + } + } + + defer func(e *error) { *e = errors.Join(*e, encoder.Close()) }(&err) + + if zedToken == nil && cursor == nil { + return errors.New("malformed existing backup, consider recreating it") + } + + req := &v1.BulkExportRelationshipsRequest{ + OptionalLimit: pageLimit, + OptionalCursor: cursor, + } + + // if a cursor is present, zedtoken is not needed (it is already in the cursor) + if zedToken != nil { + req.Consistency = &v1.Consistency{ + Requirement: &v1.Consistency_AtExactSnapshot{ + AtExactSnapshot: zedToken, + }, + } + } + + ctx := cmd.Context() + relationshipStream, err := c.BulkExportRelationships(ctx, req) + if err != nil { + return fmt.Errorf("error exporting relationships: %w", err) + } + + relationshipReadStart := time.Now() + tick := time.Tick(5 * time.Second) + bar := console.CreateProgressBar("processing backup") + var relsFilteredOut, relsProcessed uint64 + defer func() { + _ = bar.Finish() + + evt := log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)) + if isCanceled(err) { + evt.Msg("backup canceled - resume by restarting the backup command") + } else if err != nil { + evt.Msg("backup failed") + } else { + evt.Msg("finished backup") + } + }() + + for { + if err := ctx.Err(); err != nil { + if isCanceled(err) { + return context.Canceled + } + + return fmt.Errorf("aborted backup: %w", err) + } + + relsResp, err := relationshipStream.Recv() + if err != nil { + if isCanceled(err) { + return context.Canceled + } + + if !errors.Is(err, io.EOF) { + return fmt.Errorf("error receiving relationships: %w", err) + } + break + } + + for _, rel := range relsResp.Relationships { + if hasRelPrefix(rel, prefixFilter) { + if err := encoder.Append(rel); err != nil { + return fmt.Errorf("error storing relationship: %w", err) + } + } else { + relsFilteredOut++ + } + + relsProcessed++ + if err := bar.Add(1); err != nil { + return fmt.Errorf("error incrementing progress bar: %w", err) + } + + // progress fallback in case there is no TTY + if !isatty.IsTerminal(os.Stderr.Fd()) { + select { + case <-tick: + log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)). + Msg("backup progress") + default: + } + } + } + + if err := writeProgress(progressFile, relsResp); err != nil { + return err + } + } + + backupCompleted = true + return nil +} + +// encoderForNewBackup creates a new encoder for a new zed backup file. It returns the ZedToken at which the backup +// must be taken. +func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.File) (*backupformat.Encoder, *v1.ZedToken, error) { + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + + schemaResp, err := c.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}) + if err != nil { + return nil, nil, fmt.Errorf("error reading schema: %w", err) + } + if schemaResp.ReadAt == nil { + return nil, nil, fmt.Errorf("`backup` is not supported on this version of SpiceDB") + } + schema := schemaResp.SchemaText + + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. + if cobrautil.MustGetBool(cmd, "rewrite-legacy") { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + + if prefixFilter != "" { + schema, err = filterSchemaDefs(schema, prefixFilter) + if err != nil { + return nil, nil, err + } + } + + zedToken := schemaResp.ReadAt + + encoder, err := backupformat.NewEncoder(backupFile, schema, zedToken) + if err != nil { + return nil, nil, fmt.Errorf("error creating backup file encoder: %w", err) + } + + return encoder, zedToken, nil +} + +func writeProgress(progressFile *os.File, relsResp *v1.BulkExportRelationshipsResponse) error { + err := progressFile.Truncate(0) + if err != nil { + return fmt.Errorf("unable to truncate backup progress file: %w", err) + } + + _, err = progressFile.Seek(0, 0) + if err != nil { + return fmt.Errorf("unable to seek backup progress file: %w", err) + } + + _, err = progressFile.WriteString(relsResp.AfterResultCursor.Token) + if err != nil { + return fmt.Errorf("unable to write result cursor to backup progress file: %w", err) + } + + return nil +} + +// openProgressFile returns the progress marker file and the stored progress cursor if it exists, or creates +// a new one if it does not exist. If the backup file exists, but the progress marker does not, it will return an error. +// +// The progress marker file keeps track of the last successful cursor received from the server, and is used to resume +// backups in case of failure. +func openProgressFile(backupFileName string, backupAlreadyExisted bool) (*os.File, *v1.Cursor, error) { + var cursor *v1.Cursor + progressFileName := toLockFileName(backupFileName) + var progressFile *os.File + // if a backup existed + var fileMode int + readCursor, err := os.ReadFile(progressFileName) + if backupAlreadyExisted && (os.IsNotExist(err) || len(readCursor) == 0) { + return nil, nil, fmt.Errorf("backup file %s already exists", backupFileName) + } else if backupAlreadyExisted && err == nil { + cursor = &v1.Cursor{ + Token: string(readCursor), + } + + // if backup existed and there is a progress marker, the latter should not be truncated to make sure the + // cursor stays around in case of a failure before we even start ingesting from bulk export + fileMode = os.O_WRONLY | os.O_CREATE + log.Info().Str("filename", backupFileName).Msg("backup file already exists, will resume") + } else { + // if a backup did not exist, make sure to truncate the progress file + fileMode = os.O_WRONLY | os.O_CREATE | os.O_TRUNC + } + + progressFile, err = os.OpenFile(progressFileName, fileMode, 0o644) + if err != nil { + return nil, nil, err + } + + return progressFile, cursor, nil +} + +func toLockFileName(backupFileName string) string { + return backupFileName + ".lock" +} + +// computeBackupFileName computes the backup file name based. +// If no file name is provided, it derives a backup on the current context +func computeBackupFileName(cmd *cobra.Command, args []string) (string, error) { + if len(args) > 0 { + return args[0], nil + } + + configStore, secretStore := client.DefaultStorage() + token, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) + if err != nil { + return "", fmt.Errorf("failed to determine current zed context: %w", err) + } + + ex, err := os.Executable() + if err != nil { + return "", err + } + exPath := filepath.Dir(ex) + + backupFileName := filepath.Join(exPath, token.Name+".zedbackup") + + return backupFileName, nil +} + +func isCanceled(err error) bool { + if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled { + return true + } + + return errors.Is(err, context.Canceled) +} + +func openRestoreFile(filename string) (*os.File, int64, error) { + if filename == "" { + log.Trace().Str("filename", "(stdin)").Send() + return os.Stdin, -1, nil + } + + log.Trace().Str("filename", filename).Send() + + stats, err := os.Stat(filename) + if err != nil { + return nil, 0, fmt.Errorf("unable to stat restore file: %w", err) + } + + f, err := os.Open(filename) + if err != nil { + return nil, 0, fmt.Errorf("unable to open restore file: %w", err) + } + + return f, stats.Size(), nil +} + +func backupRestoreCmdFunc(cmd *cobra.Command, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + if loadedToken := decoder.ZedToken(); loadedToken != nil { + log.Debug().Str("revision", loadedToken.Token).Msg("parsed revision") + } + + schema := decoder.Schema() + + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. + if cobrautil.MustGetBool(cmd, "rewrite-legacy") { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + if prefixFilter != "" { + schema, err = filterSchemaDefs(schema, prefixFilter) + if err != nil { + return err + } + } + log.Debug().Str("schema", schema).Bool("filtered", prefixFilter != "").Msg("parsed schema") + + c, err := client.NewClient(cmd) + if err != nil { + return fmt.Errorf("unable to initialize client: %w", err) + } + + batchSize := cobrautil.MustGetUint(cmd, "batch-size") + batchesPerTransaction := cobrautil.MustGetUint(cmd, "batches-per-transaction") + + strategy, err := GetEnum[ConflictStrategy](cmd, "conflict-strategy", conflictStrategyMapping) + if err != nil { + return err + } + disableRetries := cobrautil.MustGetBool(cmd, "disable-retries") + requestTimeout := cobrautil.MustGetDuration(cmd, "request-timeout") + + return newRestorer(schema, decoder, c, prefixFilter, batchSize, batchesPerTransaction, strategy, + disableRetries, requestTimeout).restoreFromDecoder(cmd.Context()) +} + +// GetEnum is a helper for getting an enum value from a string cobra flag. +func GetEnum[E constraints.Integer](cmd *cobra.Command, name string, mapping map[string]E) (E, error) { + value := cobrautil.MustGetString(cmd, name) + value = strings.TrimSpace(strings.ToLower(value)) + if enum, ok := mapping[value]; ok { + return enum, nil + } + + var zeroValueE E + return zeroValueE, fmt.Errorf("unexpected flag '%s' value '%s': should be one of %v", name, value, maps.Keys(mapping)) +} + +func backupParseSchemaCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + schema := decoder.Schema() + + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. + if cobrautil.MustGetBool(cmd, "rewrite-legacy") { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + if prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter"); prefixFilter != "" { + schema, err = filterSchemaDefs(schema, prefixFilter) + if err != nil { + return err + } + } + + _, err = fmt.Fprintln(out, schema) + return err +} + +func backupParseRevisionCmdFunc(_ *cobra.Command, out io.Writer, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + loadedToken := decoder.ZedToken() + if loadedToken == nil { + return fmt.Errorf("failed to parse decoded revision") + } + + _, err = fmt.Fprintln(out, loadedToken.Token) + return err +} + +func backupRedactCmdFunc(cmd *cobra.Command, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return fmt.Errorf("error creating restore file decoder: %w", err) + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + filename := args[0] + ".redacted" + writer, _, err := createBackupFile(filename, doNotReturnIfExists) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, writer.Close()) }(&err) + + redactor, err := backupformat.NewRedactor(decoder, writer, backupformat.RedactionOptions{ + RedactDefinitions: cobrautil.MustGetBool(cmd, "redact-definitions"), + RedactRelations: cobrautil.MustGetBool(cmd, "redact-relations"), + RedactObjectIDs: cobrautil.MustGetBool(cmd, "redact-object-ids"), + }) + if err != nil { + return fmt.Errorf("error creating redactor: %w", err) + } + + defer func(e *error) { *e = errors.Join(*e, redactor.Close()) }(&err) + bar := console.CreateProgressBar("redacting backup") + var written int64 + for { + if err := cmd.Context().Err(); err != nil { + return fmt.Errorf("aborted redaction: %w", err) + } + + err := redactor.Next() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return fmt.Errorf("error redacting: %w", err) + } + + written++ + if err := bar.Set64(written); err != nil { + return fmt.Errorf("error incrementing progress bar: %w", err) + } + } + + if err := bar.Finish(); err != nil { + return fmt.Errorf("error finalizing progress bar: %w", err) + } + + fmt.Println("Redaction map:") + fmt.Println("--------------") + fmt.Println() + + // Draw a table of definitions, caveats and relations mapped. + tbl := table.New("Definition Name", "Redacted Name") + for k, v := range redactor.RedactionMap().Definitions { + tbl.AddRow(k, v) + } + + tbl.Print() + fmt.Println() + + if len(redactor.RedactionMap().Caveats) > 0 { + tbl = table.New("Caveat Name", "Redacted Name") + for k, v := range redactor.RedactionMap().Caveats { + tbl.AddRow(k, v) + } + tbl.Print() + fmt.Println() + } + + tbl = table.New("Relation/Permission Name", "Redacted Name") + for k, v := range redactor.RedactionMap().Relations { + tbl.AddRow(k, v) + } + tbl.Print() + fmt.Println() + + if len(redactor.RedactionMap().ObjectIDs) > 0 && cobrautil.MustGetBool(cmd, "print-redacted-object-ids") { + tbl = table.New("Object ID", "Redacted Object ID") + for k, v := range redactor.RedactionMap().ObjectIDs { + tbl.AddRow(k, v) + } + tbl.Print() + fmt.Println() + } + + return nil +} + +func backupParseRelsCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error { + prefix := cobrautil.MustGetString(cmd, "prefix-filter") + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + for rel, err := decoder.Next(); rel != nil && err == nil; rel, err = decoder.Next() { + if !hasRelPrefix(rel, prefix) { + continue + } + + relString, err := tuple.V1StringRelationship(rel) + if err != nil { + return err + } + + if _, err = fmt.Fprintln(out, replaceRelString(relString)); err != nil { + return err + } + } + + return nil +} + +func decoderFromArgs(args ...string) (*backupformat.Decoder, io.Closer, error) { + filename := "" // Default to stdin. + if len(args) > 0 { + filename = args[0] + } + + f, _, err := openRestoreFile(filename) + if err != nil { + return nil, nil, err + } + + decoder, err := backupformat.NewDecoder(f) + if err != nil { + return nil, nil, fmt.Errorf("error creating restore file decoder: %w", err) + } + + return decoder, f, nil +} + +func replaceRelString(rel string) string { + rel = strings.Replace(rel, "@", " ", 1) + return strings.Replace(rel, "#", " ", 1) +} + +func rewriteLegacy(schema string) string { + schema = string(missingAllowedTypes.ReplaceAll([]byte(schema), []byte("\n/* deleted missing allowed type error */"))) + return string(shortRelations.ReplaceAll([]byte(schema), []byte("\n/* deleted short relation name */"))) +} + +var sizeErrorRegEx = regexp.MustCompile(`received message larger than max \((\d+) vs. (\d+)\)`) + +func addSizeErrInfo(err error) error { + if err == nil { + return nil + } + + code := status.Code(err) + if code != codes.ResourceExhausted { + return err + } + + if !strings.Contains(err.Error(), "received message larger than max") { + return err + } + + matches := sizeErrorRegEx.FindStringSubmatch(err.Error()) + if len(matches) != 3 { + return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err) + } + + necessaryByteCount, atoiErr := strconv.Atoi(matches[1]) + if atoiErr != nil { + return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err) + } + + return fmt.Errorf("%w: set flag --max-message-size=%d to increase the maximum allowable size", err, 2*necessaryByteCount) +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/cmd.go b/vendor/github.com/authzed/zed/internal/cmd/cmd.go new file mode 100644 index 0000000..a2e5332 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/cmd.go @@ -0,0 +1,187 @@ +package cmd + +import ( + "context" + "errors" + "io" + "os" + "os/signal" + "strings" + "syscall" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/cobrautil/v2/cobrazerolog" + "github.com/mattn/go-isatty" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + + "github.com/authzed/zed/internal/commands" +) + +var ( + SyncFlagsCmdFunc = cobrautil.SyncViperPreRunE("ZED") + errParsing = errors.New("parsing error") +) + +func init() { + // NOTE: this is mostly to set up logging in the case where + // the command doesn't exist or the construction of the command + // errors out before the PersistentPreRunE setup in the below function. + // It helps keep log output visually consistent for a user even in + // exceptional cases. + var output io.Writer + + if isatty.IsTerminal(os.Stdout.Fd()) { + output = zerolog.ConsoleWriter{Out: os.Stderr} + } else { + output = os.Stderr + } + + l := zerolog.New(output).With().Timestamp().Logger() + + log.Logger = l +} + +var flagError = flagErrorFunc + +func flagErrorFunc(cmd *cobra.Command, err error) error { + cmd.Println(err) + cmd.Println(cmd.UsageString()) + return errParsing +} + +// InitialiseRootCmd This function is utilised to generate docs for zed +func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command { + rootCmd := &cobra.Command{ + Use: "zed", + Short: "SpiceDB CLI, built by AuthZed", + Long: "A command-line client for managing SpiceDB clusters.", + PersistentPreRunE: cobrautil.CommandStack( + zl.RunE(), + SyncFlagsCmdFunc, + commands.InjectRequestID, + ), + SilenceErrors: true, + SilenceUsage: true, + } + rootCmd.SetFlagErrorFunc(func(command *cobra.Command, err error) error { + return flagError(command, err) + }) + + zl.RegisterFlags(rootCmd.PersistentFlags()) + + rootCmd.PersistentFlags().String("endpoint", "", "spicedb gRPC API endpoint") + rootCmd.PersistentFlags().String("permissions-system", "", "permissions system to query") + rootCmd.PersistentFlags().String("hostname-override", "", "override the hostname used in the connection to the endpoint") + rootCmd.PersistentFlags().String("token", "", "token used to authenticate to SpiceDB") + rootCmd.PersistentFlags().String("certificate-path", "", "path to certificate authority used to verify secure connections") + rootCmd.PersistentFlags().Bool("insecure", false, "connect over a plaintext connection") + rootCmd.PersistentFlags().Bool("skip-version-check", false, "if true, no version check is performed against the server") + rootCmd.PersistentFlags().Bool("no-verify-ca", false, "do not attempt to verify the server's certificate chain and host name") + rootCmd.PersistentFlags().Bool("debug", false, "enable debug logging") + rootCmd.PersistentFlags().String("request-id", "", "optional id to send along with SpiceDB requests for tracing") + rootCmd.PersistentFlags().Int("max-message-size", 0, "maximum size *in bytes* (defaults to 4_194_304 bytes ~= 4MB) of a gRPC message that can be sent or received by zed") + rootCmd.PersistentFlags().String("proxy", "", "specify a SOCKS5 proxy address") + rootCmd.PersistentFlags().Uint("max-retries", 10, "maximum number of sequential retries to attempt when a request fails") + _ = rootCmd.PersistentFlags().MarkHidden("debug") // This cannot return its error. + + versionCmd := &cobra.Command{ + Use: "version", + Short: "Display zed and SpiceDB version information", + RunE: versionCmdFunc, + } + cobrautil.RegisterVersionFlags(versionCmd.Flags()) + versionCmd.Flags().Bool("include-remote-version", true, "whether to display the version of Authzed or SpiceDB for the current context") + rootCmd.AddCommand(versionCmd) + + // Register root-level aliases + rootCmd.AddCommand(&cobra.Command{ + Use: "use <context>", + Short: "Alias for `zed context use`", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + RunE: contextUseCmdFunc, + ValidArgsFunction: ContextGet, + }) + + // Register CLI-only commands. + registerContextCmd(rootCmd) + registerImportCmd(rootCmd) + registerValidateCmd(rootCmd) + registerBackupCmd(rootCmd) + registerPreviewCmd(rootCmd) + + // Register shared commands. + commands.RegisterPermissionCmd(rootCmd) + + relCmd := commands.RegisterRelationshipCmd(rootCmd) + + commands.RegisterWatchCmd(rootCmd) + commands.RegisterWatchRelationshipCmd(relCmd) + + schemaCmd := commands.RegisterSchemaCmd(rootCmd) + registerAdditionalSchemaCmds(schemaCmd) + + return rootCmd +} + +func Run() { + if err := runWithoutExit(); err != nil { + os.Exit(1) + } +} + +func runWithoutExit() error { + zl := cobrazerolog.New(cobrazerolog.WithPreRunLevel(zerolog.DebugLevel)) + + rootCmd := InitialiseRootCmd(zl) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + signalChan := make(chan os.Signal, 2) + signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM) + defer func() { + signal.Stop(signalChan) + cancel() + }() + + go func() { + select { + case <-signalChan: + cancel() + case <-ctx.Done(): + } + }() + + return handleError(rootCmd, rootCmd.ExecuteContext(ctx)) +} + +func handleError(command *cobra.Command, err error) error { + if err == nil { + return nil + } + // this snippet of code is taken from Command.ExecuteC in order to determine the command that was ultimately + // parsed. This is necessary to be able to print the proper command-specific usage + var findErr error + var cmdToExecute *cobra.Command + args := os.Args[1:] + if command.TraverseChildren { + cmdToExecute, _, findErr = command.Traverse(args) + } else { + cmdToExecute, _, findErr = command.Find(args) + } + if findErr != nil { + cmdToExecute = command + } + + if errors.Is(err, commands.ValidationError{}) { + _ = flagError(cmdToExecute, err) + } else if err != nil && strings.Contains(err.Error(), "unknown command") { + _ = flagError(cmdToExecute, err) + } else if !errors.Is(err, errParsing) { + log.Err(err).Msg("terminated with errors") + } + + return err +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/context.go b/vendor/github.com/authzed/zed/internal/cmd/context.go new file mode 100644 index 0000000..d5a0eec --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/context.go @@ -0,0 +1,202 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/stringz" + "github.com/spf13/cobra" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/console" + "github.com/authzed/zed/internal/printers" + "github.com/authzed/zed/internal/storage" +) + +func registerContextCmd(rootCmd *cobra.Command) { + rootCmd.AddCommand(contextCmd) + + contextCmd.AddCommand(contextListCmd) + contextListCmd.Flags().Bool("reveal-tokens", false, "display secrets in results") + + contextCmd.AddCommand(contextSetCmd) + contextCmd.AddCommand(contextRemoveCmd) + contextCmd.AddCommand(contextUseCmd) +} + +var contextCmd = &cobra.Command{ + Use: "context <subcommand>", + Short: "Manage configurations for connecting to SpiceDB deployments", + Aliases: []string{"ctx"}, +} + +var contextListCmd = &cobra.Command{ + Use: "list", + Short: "Lists all available contexts", + Aliases: []string{"ls"}, + Args: commands.ValidationWrapper(cobra.ExactArgs(0)), + ValidArgsFunction: cobra.NoFileCompletions, + RunE: contextListCmdFunc, +} + +var contextSetCmd = &cobra.Command{ + Use: "set <name> <endpoint> <api-token>", + Short: "Creates or overwrite a context", + Args: commands.ValidationWrapper(cobra.ExactArgs(3)), + ValidArgsFunction: cobra.NoFileCompletions, + RunE: contextSetCmdFunc, +} + +var contextRemoveCmd = &cobra.Command{ + Use: "remove <system>", + Short: "Removes a context", + Aliases: []string{"rm"}, + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + ValidArgsFunction: ContextGet, + RunE: contextRemoveCmdFunc, +} + +var contextUseCmd = &cobra.Command{ + Use: "use <system>", + Short: "Sets a context as the current context", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + ValidArgsFunction: ContextGet, + RunE: contextUseCmdFunc, +} + +func ContextGet(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { + _, secretStore := client.DefaultStorage() + secrets, err := secretStore.Get() + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + names := make([]string, 0, len(secrets.Tokens)) + for _, token := range secrets.Tokens { + names = append(names, token.Name) + } + + return names, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveNoSpace | cobra.ShellCompDirectiveKeepOrder +} + +func contextListCmdFunc(cmd *cobra.Command, _ []string) error { + cfgStore, secretStore := client.DefaultStorage() + secrets, err := secretStore.Get() + if err != nil { + return err + } + + cfg, err := cfgStore.Get() + if err != nil { + return err + } + + rows := make([][]string, 0, len(secrets.Tokens)) + for _, token := range secrets.Tokens { + current := "" + if token.Name == cfg.CurrentToken { + current = " ✓ " + } + secret := token.APIToken + if !cobrautil.MustGetBool(cmd, "reveal-tokens") { + secret = token.Redacted() + } + + var certStr string + if token.IsInsecure() { + certStr = "insecure" + } else if token.HasNoVerifyCA() { + certStr = "no-verify-ca" + } else if _, ok := token.Certificate(); ok { + certStr = "custom" + } else { + certStr = "system" + } + + rows = append(rows, []string{ + current, + token.Name, + token.Endpoint, + secret, + certStr, + }) + } + + printers.PrintTable(os.Stdout, []string{"current", "name", "endpoint", "token", "tls cert"}, rows) + + return nil +} + +func contextSetCmdFunc(cmd *cobra.Command, args []string) error { + var name, endpoint, apiToken string + err := stringz.Unpack(args, &name, &endpoint, &apiToken) + if err != nil { + return err + } + + certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path") + var certBytes []byte + if certPath != "" { + certBytes, err = os.ReadFile(certPath) + if err != nil { + return fmt.Errorf("failed to read ceritficate: %w", err) + } + } + + insecure := cobrautil.MustGetBool(cmd, "insecure") + noVerifyCA := cobrautil.MustGetBool(cmd, "no-verify-ca") + cfgStore, secretStore := client.DefaultStorage() + err = storage.PutToken(storage.Token{ + Name: name, + Endpoint: stringz.DefaultEmpty(endpoint, "grpc.authzed.com:443"), + APIToken: apiToken, + Insecure: &insecure, + NoVerifyCA: &noVerifyCA, + CACert: certBytes, + }, secretStore) + if err != nil { + return err + } + + return storage.SetCurrentToken(name, cfgStore, secretStore) +} + +func contextRemoveCmdFunc(_ *cobra.Command, args []string) error { + // If the token is what's currently being used, remove it from the config. + cfgStore, secretStore := client.DefaultStorage() + cfg, err := cfgStore.Get() + if err != nil { + return err + } + + if cfg.CurrentToken == args[0] { + cfg.CurrentToken = "" + } + + err = cfgStore.Put(cfg) + if err != nil { + return err + } + + return storage.RemoveToken(args[0], secretStore) +} + +func contextUseCmdFunc(_ *cobra.Command, args []string) error { + cfgStore, secretStore := client.DefaultStorage() + switch len(args) { + case 0: + cfg, err := cfgStore.Get() + if err != nil { + return err + } + console.Println(cfg.CurrentToken) + case 1: + return storage.SetCurrentToken(args[0], cfgStore, secretStore) + default: + panic("cobra command did not enforce valid number of args") + } + + return nil +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/import.go b/vendor/github.com/authzed/zed/internal/cmd/import.go new file mode 100644 index 0000000..687d42e --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/import.go @@ -0,0 +1,181 @@ +package cmd + +import ( + "bufio" + "context" + "fmt" + "net/url" + "strings" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/validationfile" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/decode" + "github.com/authzed/zed/internal/grpcutil" +) + +func registerImportCmd(rootCmd *cobra.Command) { + rootCmd.AddCommand(importCmd) + importCmd.Flags().Int("batch-size", 1000, "import batch size") + importCmd.Flags().Int("workers", 1, "number of concurrent batching workers") + importCmd.Flags().Bool("schema", true, "import schema") + importCmd.Flags().Bool("relationships", true, "import relationships") + importCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before importing") +} + +var importCmd = &cobra.Command{ + Use: "import <url>", + Short: "Imports schema and relationships from a file or url", + Example: ` + From a gist: + zed import https://gist.github.com/ecordell/8e3b613a677e3c844742cf24421c08b6 + + From a playground link: + zed import https://play.authzed.com/s/iksdFvCtvnkR/schema + + From pastebin: + zed import https://pastebin.com/8qU45rVK + + From a devtools instance: + zed import https://localhost:8443/download + + From a local file (with prefix): + zed import file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml + + From a local file (no prefix): + zed import authzed-x7izWU8_2Gw3.yaml + + Only schema: + zed import --relationships=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml + + Only relationships: + zed import --schema=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml + + With schema definition prefix: + zed import --schema-definition-prefix=mypermsystem file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml +`, + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: importCmdFunc, +} + +func importCmdFunc(cmd *cobra.Command, args []string) error { + client, err := client.NewClient(cmd) + if err != nil { + return err + } + u, err := url.Parse(args[0]) + if err != nil { + return err + } + + decoder, err := decode.DecoderForURL(u) + if err != nil { + return err + } + var p validationfile.ValidationFile + if _, _, err := decoder(&p); err != nil { + return err + } + + prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), client, nil) + if err != nil { + return err + } + + if cobrautil.MustGetBool(cmd, "schema") { + if err := importSchema(cmd.Context(), client, p.Schema.Schema, prefix); err != nil { + return err + } + } + + if cobrautil.MustGetBool(cmd, "relationships") { + batchSize := cobrautil.MustGetInt(cmd, "batch-size") + workers := cobrautil.MustGetInt(cmd, "workers") + if err := importRelationships(cmd.Context(), client, p.Relationships.RelationshipsString, prefix, batchSize, workers); err != nil { + return err + } + } + + return err +} + +func importSchema(ctx context.Context, client client.Client, schema string, definitionPrefix string) error { + log.Info().Msg("importing schema") + + // Recompile the schema with the specified prefix. + schemaText, err := rewriteSchema(schema, definitionPrefix) + if err != nil { + return err + } + + // Write the recompiled and regenerated schema. + request := &v1.WriteSchemaRequest{Schema: schemaText} + log.Trace().Interface("request", request).Str("schema", schemaText).Msg("writing schema") + + if _, err := client.WriteSchema(ctx, request); err != nil { + return err + } + + return nil +} + +func importRelationships(ctx context.Context, client client.Client, relationships string, definitionPrefix string, batchSize int, workers int) error { + relationshipUpdates := make([]*v1.RelationshipUpdate, 0) + scanner := bufio.NewScanner(strings.NewReader(relationships)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + if strings.HasPrefix(line, "//") { + continue + } + rel, err := tuple.ParseV1Rel(line) + if err != nil { + return fmt.Errorf("failed to parse %s as relationship", line) + } + log.Trace().Str("line", line).Send() + + // Rewrite the prefix on the references, if any. + if len(definitionPrefix) > 0 { + rel.Resource.ObjectType = fmt.Sprintf("%s/%s", definitionPrefix, rel.Resource.ObjectType) + rel.Subject.Object.ObjectType = fmt.Sprintf("%s/%s", definitionPrefix, rel.Subject.Object.ObjectType) + } + + relationshipUpdates = append(relationshipUpdates, &v1.RelationshipUpdate{ + Operation: v1.RelationshipUpdate_OPERATION_TOUCH, + Relationship: rel, + }) + } + if err := scanner.Err(); err != nil { + return err + } + + log.Info(). + Int("batch_size", batchSize). + Int("workers", workers). + Int("count", len(relationshipUpdates)). + Msg("importing relationships") + + err := grpcutil.ConcurrentBatch(ctx, len(relationshipUpdates), batchSize, workers, func(ctx context.Context, no int, start int, end int) error { + request := &v1.WriteRelationshipsRequest{Updates: relationshipUpdates[start:end]} + _, err := client.WriteRelationships(ctx, request) + if err != nil { + return err + } + + log.Info(). + Int("batch_no", no). + Int("count", len(relationshipUpdates[start:end])). + Msg("imported relationships") + return nil + }) + return err +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/preview.go b/vendor/github.com/authzed/zed/internal/cmd/preview.go new file mode 100644 index 0000000..279e310 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/preview.go @@ -0,0 +1,120 @@ +package cmd + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/ccoveille/go-safecast" + "github.com/jzelinskie/cobrautil/v2" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "golang.org/x/term" + + newcompiler "github.com/authzed/spicedb/pkg/composableschemadsl/compiler" + newinput "github.com/authzed/spicedb/pkg/composableschemadsl/input" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + + "github.com/authzed/zed/internal/commands" +) + +func registerPreviewCmd(rootCmd *cobra.Command) { + rootCmd.AddCommand(previewCmd) + + previewCmd.AddCommand(schemaCmd) + + schemaCmd.AddCommand(schemaCompileCmd) + schemaCompileCmd.Flags().String("out", "", "output filepath; omitting writes to stdout") +} + +var previewCmd = &cobra.Command{ + Use: "preview <subcommand>", + Short: "Experimental commands that have been made available for preview", +} + +var schemaCmd = &cobra.Command{ + Use: "schema <subcommand>", + Short: "Manage schema for a permissions system", +} + +var schemaCompileCmd = &cobra.Command{ + Use: "compile <file>", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + Short: "Compile a schema that uses extended syntax into one that can be written to SpiceDB", + Example: ` + Write to stdout: + zed preview schema compile root.zed + Write to an output file: + zed preview schema compile root.zed --out compiled.zed + `, + ValidArgsFunction: commands.FileExtensionCompletions("zed"), + RunE: schemaCompileCmdFunc, +} + +// Compiles an input schema written in the new composable schema syntax +// and produces it as a fully-realized schema +func schemaCompileCmdFunc(cmd *cobra.Command, args []string) error { + stdOutFd, err := safecast.ToInt(uint(os.Stdout.Fd())) + if err != nil { + return err + } + outputFilepath := cobrautil.MustGetString(cmd, "out") + if outputFilepath == "" && !term.IsTerminal(stdOutFd) { + return fmt.Errorf("must provide stdout or output file path") + } + + inputFilepath := args[0] + inputSourceFolder := filepath.Dir(inputFilepath) + var schemaBytes []byte + schemaBytes, err = os.ReadFile(inputFilepath) + if err != nil { + return fmt.Errorf("failed to read schema file: %w", err) + } + log.Trace().Str("schema", string(schemaBytes)).Str("file", args[0]).Msg("read schema from file") + + if len(schemaBytes) == 0 { + return errors.New("attempted to compile empty schema") + } + + compiled, err := newcompiler.Compile(newcompiler.InputSchema{ + Source: newinput.Source(inputFilepath), + SchemaString: string(schemaBytes), + }, newcompiler.AllowUnprefixedObjectType(), + newcompiler.SourceFolder(inputSourceFolder)) + if err != nil { + return err + } + + // Attempt to cast one kind of OrderedDefinition to another + oldDefinitions := make([]compiler.SchemaDefinition, 0, len(compiled.OrderedDefinitions)) + for _, definition := range compiled.OrderedDefinitions { + oldDefinition, ok := definition.(compiler.SchemaDefinition) + if !ok { + return fmt.Errorf("could not convert definition to old schemadefinition: %v", oldDefinition) + } + oldDefinitions = append(oldDefinitions, oldDefinition) + } + + // This is where we functionally assert that the two systems are compatible + generated, _, err := generator.GenerateSchema(oldDefinitions) + if err != nil { + return fmt.Errorf("could not generate resulting schema: %w", err) + } + + // Add a newline at the end for hygiene's sake + terminated := generated + "\n" + + if outputFilepath == "" { + // Print to stdout + fmt.Print(terminated) + } else { + err = os.WriteFile(outputFilepath, []byte(terminated), 0o_600) + if err != nil { + return err + } + } + + return nil +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/restorer.go b/vendor/github.com/authzed/zed/internal/cmd/restorer.go new file mode 100644 index 0000000..468064f --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/restorer.go @@ -0,0 +1,446 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + "time" + + "github.com/ccoveille/go-safecast" + "github.com/cenkalti/backoff/v4" + "github.com/mattn/go-isatty" + "github.com/rs/zerolog/log" + "github.com/samber/lo" + "github.com/schollz/progressbar/v3" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/console" + "github.com/authzed/zed/pkg/backupformat" +) + +type ConflictStrategy int + +const ( + Fail ConflictStrategy = iota + Skip + Touch + + defaultBackoff = 50 * time.Millisecond + defaultMaxRetries = 10 +) + +var conflictStrategyMapping = map[string]ConflictStrategy{ + "fail": Fail, + "skip": Skip, + "touch": Touch, +} + +// Fallback for datastore implementations on SpiceDB < 1.29.0 not returning proper gRPC codes +// Remove once https://github.com/authzed/spicedb/pull/1688 lands +var ( + txConflictCodes = []string{ + "SQLSTATE 23505", // CockroachDB + "Error 1062 (23000)", // MySQL + } + retryableErrorCodes = []string{ + "retryable error", // CockroachDB, PostgreSQL + "try restarting transaction", "Error 1205", // MySQL + } +) + +type restorer struct { + schema string + decoder *backupformat.Decoder + client client.Client + prefixFilter string + batchSize uint + batchesPerTransaction uint + conflictStrategy ConflictStrategy + disableRetryErrors bool + bar *progressbar.ProgressBar + + // stats + filteredOutRels uint + writtenRels uint + writtenBatches uint + skippedRels uint + skippedBatches uint + duplicateRels uint + duplicateBatches uint + totalRetries uint + requestTimeout time.Duration +} + +func newRestorer(schema string, decoder *backupformat.Decoder, client client.Client, prefixFilter string, batchSize uint, + batchesPerTransaction uint, conflictStrategy ConflictStrategy, disableRetryErrors bool, + requestTimeout time.Duration, +) *restorer { + return &restorer{ + decoder: decoder, + schema: schema, + client: client, + prefixFilter: prefixFilter, + requestTimeout: requestTimeout, + batchSize: batchSize, + batchesPerTransaction: batchesPerTransaction, + conflictStrategy: conflictStrategy, + disableRetryErrors: disableRetryErrors, + bar: console.CreateProgressBar("restoring from backup"), + } +} + +func (r *restorer) restoreFromDecoder(ctx context.Context) error { + relationshipWriteStart := time.Now() + defer func() { + if err := r.bar.Finish(); err != nil { + log.Warn().Err(err).Msg("error finalizing progress bar") + } + }() + + r.bar.Describe("restoring schema from backup") + if _, err := r.client.WriteSchema(ctx, &v1.WriteSchemaRequest{ + Schema: r.schema, + }); err != nil { + return fmt.Errorf("unable to write schema: %w", err) + } + + relationshipWriter, err := r.client.BulkImportRelationships(ctx) + if err != nil { + return fmt.Errorf("error creating writer stream: %w", err) + } + + r.bar.Describe("restoring relationships from backup") + batch := make([]*v1.Relationship, 0, r.batchSize) + batchesToBeCommitted := make([][]*v1.Relationship, 0, r.batchesPerTransaction) + for rel, err := r.decoder.Next(); rel != nil && err == nil; rel, err = r.decoder.Next() { + if err := ctx.Err(); err != nil { + r.bar.Describe("backup restore aborted") + return fmt.Errorf("aborted restore: %w", err) + } + + if !hasRelPrefix(rel, r.prefixFilter) { + r.filteredOutRels++ + continue + } + + batch = append(batch, rel) + + if uint(len(batch))%r.batchSize == 0 { + batchesToBeCommitted = append(batchesToBeCommitted, batch) + err := relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{ + Relationships: batch, + }) + if err != nil { + // It feels non-idiomatic to check for error and perform an operation, but in gRPC, when an element + // sent over the stream fails, we need to call recvAndClose() to get the error. + if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil { + return fmt.Errorf("error committing batches: %w", err) + } + + // after an error + relationshipWriter, err = r.client.BulkImportRelationships(ctx) + if err != nil { + return fmt.Errorf("error creating new writer stream: %w", err) + } + + batchesToBeCommitted = batchesToBeCommitted[:0] + batch = batch[:0] + continue + } + + // The batch just sent is kept in batchesToBeCommitted, which is used for retries. + // Therefore, we cannot reuse the batch. Batches may fail on send, or on commit (CloseAndRecv). + batch = make([]*v1.Relationship, 0, r.batchSize) + + // if we've sent the maximum number of batches per transaction, proceed to commit + if uint(len(batchesToBeCommitted))%r.batchesPerTransaction != 0 { + continue + } + + if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil { + return fmt.Errorf("error committing batches: %w", err) + } + + relationshipWriter, err = r.client.BulkImportRelationships(ctx) + if err != nil { + return fmt.Errorf("error creating new writer stream: %w", err) + } + + batchesToBeCommitted = batchesToBeCommitted[:0] + } + } + + // Write the last batch + if len(batch) > 0 { + // Since we are going to close the stream anyway after the last batch, and given the actual error + // is only returned on CloseAndRecv(), we have to ignore the error here in order to get the actual + // underlying error that caused Send() to fail. It also gives us the opportunity to retry it + // in case it failed. + batchesToBeCommitted = append(batchesToBeCommitted, batch) + _ = relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{Relationships: batch}) + } + + if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil { + return fmt.Errorf("error committing last set of batches: %w", err) + } + + r.bar.Describe("completed import") + if err := r.bar.Finish(); err != nil { + log.Warn().Err(err).Msg("error finalizing progress bar") + } + + totalTime := time.Since(relationshipWriteStart) + log.Info(). + Uint("batches", r.writtenBatches). + Uint("relationships_loaded", r.writtenRels). + Uint("relationships_skipped", r.skippedRels). + Uint("duplicate_relationships", r.duplicateRels). + Uint("relationships_filtered_out", r.filteredOutRels). + Uint("retried_errors", r.totalRetries). + Uint64("perSecond", perSec(uint64(r.writtenRels), totalTime)). + Stringer("duration", totalTime). + Msg("finished restore") + return nil +} + +func (r *restorer) commitStream(ctx context.Context, bulkImportClient v1.ExperimentalService_BulkImportRelationshipsClient, + batchesToBeCommitted [][]*v1.Relationship, +) error { + var numLoaded, expectedLoaded, retries uint + for _, b := range batchesToBeCommitted { + expectedLoaded += uint(len(b)) + } + + resp, err := bulkImportClient.CloseAndRecv() // transaction commit happens here + + // Failure to commit transaction means the stream is closed, so it can't be reused any further + // The retry will be done using WriteRelationships instead of BulkImportRelationships + // This lets us retry with TOUCH semantics in case of failure due to duplicates + retryable := isRetryableError(err) + conflict := isAlreadyExistsError(err) + canceled, cancelErr := isCanceledError(ctx.Err(), err) + unknown := !retryable && !conflict && !canceled && err != nil + + numBatches := uint(len(batchesToBeCommitted)) + + switch { + case canceled: + r.bar.Describe("backup restore aborted") + return cancelErr + case unknown: + r.bar.Describe("failed with unrecoverable error") + return fmt.Errorf("error finalizing write of %d batches: %w", len(batchesToBeCommitted), err) + case retryable && r.disableRetryErrors: + return err + case conflict && r.conflictStrategy == Skip: + r.skippedRels += expectedLoaded + r.skippedBatches += numBatches + r.duplicateBatches += numBatches + r.duplicateRels += expectedLoaded + r.bar.Describe("skipping conflicting batch") + case conflict && r.conflictStrategy == Touch: + r.bar.Describe("touching conflicting batch") + r.duplicateRels += expectedLoaded + r.duplicateBatches += numBatches + r.totalRetries++ + numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted) + if err != nil { + return fmt.Errorf("failed to write retried batch: %w", err) + } + + retries++ // account for the initial attempt + r.writtenBatches += numBatches + r.writtenRels += numLoaded + case conflict && r.conflictStrategy == Fail: + r.bar.Describe("conflict detected, aborting restore") + return fmt.Errorf("duplicate relationships found") + case retryable: + r.bar.Describe("retrying after error") + r.totalRetries++ + numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted) + if err != nil { + return fmt.Errorf("failed to write retried batch: %w", err) + } + + retries++ // account for the initial attempt + r.writtenBatches += numBatches + r.writtenRels += numLoaded + default: + r.bar.Describe("restoring relationships from backup") + r.writtenBatches += numBatches + } + + // it was a successful transaction commit without duplicates + if resp != nil { + numLoaded, err := safecast.ToUint(resp.NumLoaded) + if err != nil { + return spiceerrors.MustBugf("could not cast numLoaded to uint") + } + r.writtenRels += numLoaded + if uint64(expectedLoaded) != resp.NumLoaded { + log.Warn().Uint64("loaded", resp.NumLoaded).Uint("expected", expectedLoaded).Msg("unexpected number of relationships loaded") + } + } + + writtenAndSkipped, err := safecast.ToInt64(r.writtenRels + r.skippedRels) + if err != nil { + return fmt.Errorf("too many written and skipped rels for an int64") + } + + if err := r.bar.Set64(writtenAndSkipped); err != nil { + return fmt.Errorf("error incrementing progress bar: %w", err) + } + + if !isatty.IsTerminal(os.Stderr.Fd()) { + log.Trace(). + Uint("batches_written", r.writtenBatches). + Uint("relationships_written", r.writtenRels). + Uint("duplicate_batches", r.duplicateBatches). + Uint("duplicate_relationships", r.duplicateRels). + Uint("skipped_batches", r.skippedBatches). + Uint("skipped_relationships", r.skippedRels). + Uint("retries", retries). + Msg("restore progress") + } + + return nil +} + +// writeBatchesWithRetry writes a set of batches using touch semantics and without transactional guarantees - +// each batch will be committed independently. If a batch fails, it will be retried up to 10 times with a backoff. +func (r *restorer) writeBatchesWithRetry(ctx context.Context, batches [][]*v1.Relationship) (uint, uint, error) { + backoffInterval := backoff.NewExponentialBackOff() + backoffInterval.InitialInterval = defaultBackoff + backoffInterval.MaxInterval = 2 * time.Second + backoffInterval.MaxElapsedTime = 0 + backoffInterval.Reset() + + var currentRetries, totalRetries, loadedRels uint + for _, batch := range batches { + updates := lo.Map[*v1.Relationship, *v1.RelationshipUpdate](batch, func(item *v1.Relationship, _ int) *v1.RelationshipUpdate { + return &v1.RelationshipUpdate{ + Relationship: item, + Operation: v1.RelationshipUpdate_OPERATION_TOUCH, + } + }) + + for { + cancelCtx, cancel := context.WithTimeout(ctx, r.requestTimeout) + _, err := r.client.WriteRelationships(cancelCtx, &v1.WriteRelationshipsRequest{Updates: updates}) + cancel() + + if isRetryableError(err) && currentRetries < defaultMaxRetries { + // throttle the writes so we don't overwhelm the server + bo := backoffInterval.NextBackOff() + r.bar.Describe(fmt.Sprintf("retrying write with backoff %s after error (attempt %d/%d)", bo, + currentRetries+1, defaultMaxRetries)) + time.Sleep(bo) + currentRetries++ + r.totalRetries++ + totalRetries++ + continue + } + if err != nil { + return 0, 0, err + } + + currentRetries = 0 + backoffInterval.Reset() + loadedRels += uint(len(batch)) + break + } + } + + return loadedRels, totalRetries, nil +} + +func isAlreadyExistsError(err error) bool { + if err == nil { + return false + } + + if isGRPCCode(err, codes.AlreadyExists) { + return true + } + + return isContainsErrorString(err, txConflictCodes...) +} + +func isRetryableError(err error) bool { + if err == nil { + return false + } + + if isGRPCCode(err, codes.Unavailable, codes.DeadlineExceeded) { + return true + } + + if isContainsErrorString(err, retryableErrorCodes...) { + return true + } + + return errors.Is(err, context.DeadlineExceeded) +} + +func isCanceledError(errs ...error) (bool, error) { + for _, err := range errs { + if err == nil { + continue + } + + if errors.Is(err, context.Canceled) { + return true, err + } + + if isGRPCCode(err, codes.Canceled) { + return true, err + } + } + + return false, nil +} + +func isContainsErrorString(err error, errStrings ...string) bool { + if err == nil { + return false + } + + for _, errString := range errStrings { + if strings.Contains(err.Error(), errString) { + return true + } + } + + return false +} + +func isGRPCCode(err error, codes ...codes.Code) bool { + if err == nil { + return false + } + + if s, ok := status.FromError(err); ok { + for _, code := range codes { + if s.Code() == code { + return true + } + } + } + + return false +} + +func perSec(i uint64, d time.Duration) uint64 { + secs := uint64(d.Seconds()) + if secs == 0 { + return i + } + return i / secs +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/schema.go b/vendor/github.com/authzed/zed/internal/cmd/schema.go new file mode 100644 index 0000000..bbe52f3 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/schema.go @@ -0,0 +1,319 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "strings" + + "github.com/ccoveille/go-safecast" + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/stringz" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "golang.org/x/term" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/diff" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/schemadsl/input" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/console" +) + +func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) { + schemaCmd.AddCommand(schemaCopyCmd) + schemaCopyCmd.Flags().Bool("json", false, "output as JSON") + schemaCopyCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before writing") + + schemaCmd.AddCommand(schemaWriteCmd) + schemaWriteCmd.Flags().Bool("json", false, "output as JSON") + schemaWriteCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before writing") + + schemaCmd.AddCommand(schemaDiffCmd) +} + +var schemaWriteCmd = &cobra.Command{ + Use: "write <file?>", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + Short: "Write a schema file (.zed or stdin) to the current permissions system", + ValidArgsFunction: commands.FileExtensionCompletions("zed"), + RunE: schemaWriteCmdFunc, +} + +var schemaCopyCmd = &cobra.Command{ + Use: "copy <src context> <dest context>", + Short: "Copy a schema from one context into another", + Args: commands.ValidationWrapper(cobra.ExactArgs(2)), + ValidArgsFunction: ContextGet, + RunE: schemaCopyCmdFunc, +} + +var schemaDiffCmd = &cobra.Command{ + Use: "diff <before file> <after file>", + Short: "Diff two schema files", + Args: commands.ValidationWrapper(cobra.ExactArgs(2)), + RunE: schemaDiffCmdFunc, +} + +func schemaDiffCmdFunc(_ *cobra.Command, args []string) error { + beforeBytes, err := os.ReadFile(args[0]) + if err != nil { + return fmt.Errorf("failed to read before schema file: %w", err) + } + + afterBytes, err := os.ReadFile(args[1]) + if err != nil { + return fmt.Errorf("failed to read after schema file: %w", err) + } + + before, err := compiler.Compile( + compiler.InputSchema{Source: input.Source(args[0]), SchemaString: string(beforeBytes)}, + compiler.AllowUnprefixedObjectType(), + ) + if err != nil { + return err + } + + after, err := compiler.Compile( + compiler.InputSchema{Source: input.Source(args[1]), SchemaString: string(afterBytes)}, + compiler.AllowUnprefixedObjectType(), + ) + if err != nil { + return err + } + + dbefore := diff.NewDiffableSchemaFromCompiledSchema(before) + dafter := diff.NewDiffableSchemaFromCompiledSchema(after) + + schemaDiff, err := diff.DiffSchemas(dbefore, dafter, types.Default.TypeSet) + if err != nil { + return err + } + + for _, ns := range schemaDiff.AddedNamespaces { + console.Printf("Added definition: %s\n", ns) + } + + for _, ns := range schemaDiff.RemovedNamespaces { + console.Printf("Removed definition: %s\n", ns) + } + + for nsName, ns := range schemaDiff.ChangedNamespaces { + console.Printf("Changed definition: %s\n", nsName) + for _, delta := range ns.Deltas() { + console.Printf("\t %s: %s\n", delta.Type, delta.RelationName) + } + } + + for _, caveat := range schemaDiff.AddedCaveats { + console.Printf("Added caveat: %s\n", caveat) + } + + for _, caveat := range schemaDiff.RemovedCaveats { + console.Printf("Removed caveat: %s\n", caveat) + } + + return nil +} + +func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error { + _, secretStore := client.DefaultStorage() + srcClient, err := client.NewClientForContext(cmd, args[0], secretStore) + if err != nil { + return err + } + + destClient, err := client.NewClientForContext(cmd, args[1], secretStore) + if err != nil { + return err + } + + readRequest := &v1.ReadSchemaRequest{} + log.Trace().Interface("request", readRequest).Msg("requesting schema read") + + readResp, err := srcClient.ReadSchema(cmd.Context(), readRequest) + if err != nil { + log.Fatal().Err(err).Msg("failed to read schema") + } + log.Trace().Interface("response", readResp).Msg("read schema") + + prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), nil, &readResp.SchemaText) + if err != nil { + return err + } + + schemaText, err := rewriteSchema(readResp.SchemaText, prefix) + if err != nil { + return err + } + + writeRequest := &v1.WriteSchemaRequest{Schema: schemaText} + log.Trace().Interface("request", writeRequest).Msg("writing schema") + + resp, err := destClient.WriteSchema(cmd.Context(), writeRequest) + if err != nil { + log.Fatal().Err(err).Msg("failed to write schema") + } + log.Trace().Interface("response", resp).Msg("wrote schema") + + if cobrautil.MustGetBool(cmd, "json") { + prettyProto, err := commands.PrettyProto(resp) + if err != nil { + log.Fatal().Err(err).Msg("failed to convert schema to JSON") + } + + console.Println(string(prettyProto)) + return nil + } + + return nil +} + +func schemaWriteCmdFunc(cmd *cobra.Command, args []string) error { + intFd, err := safecast.ToInt(uint(os.Stdout.Fd())) + if err != nil { + return err + } + if len(args) == 0 && term.IsTerminal(intFd) { + return fmt.Errorf("must provide file path or contents via stdin") + } + + client, err := client.NewClient(cmd) + if err != nil { + return err + } + var schemaBytes []byte + switch len(args) { + case 1: + schemaBytes, err = os.ReadFile(args[0]) + if err != nil { + return fmt.Errorf("failed to read schema file: %w", err) + } + log.Trace().Str("schema", string(schemaBytes)).Str("file", args[0]).Msg("read schema from file") + case 0: + schemaBytes, err = io.ReadAll(os.Stdin) + if err != nil { + return fmt.Errorf("failed to read schema file: %w", err) + } + log.Trace().Str("schema", string(schemaBytes)).Msg("read schema from stdin") + default: + panic("schemaWriteCmdFunc called with incorrect number of arguments") + } + + if len(schemaBytes) == 0 { + return errors.New("attempted to write empty schema") + } + + prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), client, nil) + if err != nil { + return err + } + + schemaText, err := rewriteSchema(string(schemaBytes), prefix) + if err != nil { + return err + } + + request := &v1.WriteSchemaRequest{Schema: schemaText} + log.Trace().Interface("request", request).Msg("writing schema") + + resp, err := client.WriteSchema(cmd.Context(), request) + if err != nil { + log.Fatal().Err(err).Msg("failed to write schema") + } + log.Trace().Interface("response", resp).Msg("wrote schema") + + if cobrautil.MustGetBool(cmd, "json") { + prettyProto, err := commands.PrettyProto(resp) + if err != nil { + log.Fatal().Err(err).Msg("failed to convert schema to JSON") + } + + console.Println(string(prettyProto)) + return nil + } + + return nil +} + +// rewriteSchema rewrites the given existing schema to include the specified prefix on all definitions. +func rewriteSchema(existingSchemaText string, definitionPrefix string) (string, error) { + if definitionPrefix == "" { + return existingSchemaText, nil + } + + compiled, err := compiler.Compile( + compiler.InputSchema{Source: input.Source("schema"), SchemaString: existingSchemaText}, + compiler.ObjectTypePrefix(definitionPrefix), + compiler.SkipValidation(), + ) + if err != nil { + return "", err + } + + generated, _, err := generator.GenerateSchema(compiled.OrderedDefinitions) + return generated, err +} + +// determinePrefixForSchema determines the prefix to be applied to a schema that will be written. +// +// If specifiedPrefix is non-empty, it is returned immediately. +// If existingSchema is non-nil, it is parsed for the prefix. +// Otherwise, the client is used to retrieve the existing schema (if any), and the prefix is retrieved from there. +func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, client client.Client, existingSchema *string) (string, error) { + if specifiedPrefix != "" { + return specifiedPrefix, nil + } + + var schemaText string + if existingSchema != nil { + schemaText = *existingSchema + } else { + readSchemaText, err := commands.ReadSchema(ctx, client) + if err != nil { + return "", nil + } + schemaText = readSchemaText + } + + // If there is no schema found, return the empty string. + if schemaText == "" { + return "", nil + } + + // Otherwise, compile the schema and grab the prefixes of the namespaces defined. + found, err := compiler.Compile( + compiler.InputSchema{Source: input.Source("schema"), SchemaString: schemaText}, + compiler.AllowUnprefixedObjectType(), + compiler.SkipValidation(), + ) + if err != nil { + return "", err + } + + foundPrefixes := make([]string, 0, len(found.OrderedDefinitions)) + for _, def := range found.OrderedDefinitions { + if strings.Contains(def.GetName(), "/") { + parts := strings.Split(def.GetName(), "/") + foundPrefixes = append(foundPrefixes, parts[0]) + } else { + foundPrefixes = append(foundPrefixes, "") + } + } + + prefixes := stringz.Dedup(foundPrefixes) + if len(prefixes) == 1 { + prefix := prefixes[0] + log.Debug().Str("prefix", prefix).Msg("found schema definition prefix") + return prefix, nil + } + + return "", nil +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/validate.go b/vendor/github.com/authzed/zed/internal/cmd/validate.go new file mode 100644 index 0000000..ffca1c6 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/validate.go @@ -0,0 +1,414 @@ +package cmd + +import ( + "errors" + "fmt" + "net/url" + "os" + "strings" + + "github.com/ccoveille/go-safecast" + "github.com/charmbracelet/lipgloss" + "github.com/jzelinskie/cobrautil/v2" + "github.com/muesli/termenv" + "github.com/spf13/cobra" + + composable "github.com/authzed/spicedb/pkg/composableschemadsl/compiler" + "github.com/authzed/spicedb/pkg/development" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/validationfile" + + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/console" + "github.com/authzed/zed/internal/decode" + "github.com/authzed/zed/internal/printers" +) + +var ( + // NOTE: these need to be set *after* the renderer has been set, otherwise + // the forceColor setting can't work, hence the thunking. + success = func() string { + return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("10")).Render("Success!") + } + complete = func() string { return lipgloss.NewStyle().Bold(true).Render("complete") } + errorPrefix = func() string { return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("9")).Render("error: ") } + warningPrefix = func() string { + return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("3")).Render("warning: ") + } + errorMessageStyle = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true).Width(80) } + linePrefixStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("12")) } + highlightedSourceStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) } + highlightedLineStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) } + codeStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("8")) } + highlightedCodeStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("15")) } + traceStyle = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true) } +) + +func registerValidateCmd(cmd *cobra.Command) { + validateCmd.Flags().Bool("force-color", false, "force color code output even in non-tty environments") + validateCmd.Flags().String("schema-type", "", "force validation according to specific schema syntax (\"\", \"composable\", \"standard\")") + cmd.AddCommand(validateCmd) +} + +var validateCmd = &cobra.Command{ + Use: "validate <validation_file_or_schema_file>", + Short: "Validates the given validation file (.yaml, .zaml) or schema file (.zed)", + Example: ` + From a local file (with prefix): + zed validate file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml + + From a local file (no prefix): + zed validate authzed-x7izWU8_2Gw3.yaml + + From a gist: + zed validate https://gist.github.com/ecordell/8e3b613a677e3c844742cf24421c08b6 + + From a playground link: + zed validate https://play.authzed.com/s/iksdFvCtvnkR/schema + + From pastebin: + zed validate https://pastebin.com/8qU45rVK + + From a devtools instance: + zed validate https://localhost:8443/download`, + Args: commands.ValidationWrapper(cobra.MinimumNArgs(1)), + ValidArgsFunction: commands.FileExtensionCompletions("zed", "yaml", "zaml"), + PreRunE: validatePreRunE, + RunE: func(cmd *cobra.Command, filenames []string) error { + result, shouldExit, err := validateCmdFunc(cmd, filenames) + if err != nil { + return err + } + console.Print(result) + if shouldExit { + os.Exit(1) + } + return nil + }, + + // A schema that causes the parser/compiler to error will halt execution + // of this command with an error. In that case, we want to just display the error, + // rather than showing usage for this command. + SilenceUsage: true, +} + +var validSchemaTypes = []string{"", "standard", "composable"} + +func validatePreRunE(cmd *cobra.Command, _ []string) error { + // Override lipgloss's autodetection of whether it's in a terminal environment + // and display things in color anyway. This can be nice in CI environments that + // support it. + setForceColor := cobrautil.MustGetBool(cmd, "force-color") + if setForceColor { + lipgloss.SetColorProfile(termenv.ANSI256) + } + + schemaType := cobrautil.MustGetString(cmd, "schema-type") + schemaTypeValid := false + for _, validType := range validSchemaTypes { + if schemaType == validType { + schemaTypeValid = true + } + } + if !schemaTypeValid { + return fmt.Errorf("schema-type must be one of \"\", \"standard\", \"composable\". received: %s", schemaType) + } + + return nil +} + +// validateCmdFunc returns the string to print to the user, whether to return a non-zero status code, and any errors. +func validateCmdFunc(cmd *cobra.Command, filenames []string) (string, bool, error) { + // Initialize variables for multiple files + var ( + totalFiles = len(filenames) + successfullyValidatedFiles = 0 + shouldExit = false + toPrint = &strings.Builder{} + schemaType = cobrautil.MustGetString(cmd, "schema-type") + ) + + for _, filename := range filenames { + // If we're running over multiple files, print the filename for context/debugging purposes + if totalFiles > 1 { + toPrint.WriteString(filename + "\n") + } + + u, err := url.Parse(filename) + if err != nil { + return "", false, err + } + + decoder, err := decode.DecoderForURL(u) + if err != nil { + return "", false, err + } + + var parsed validationfile.ValidationFile + validateContents, isOnlySchema, err := decoder(&parsed) + standardErrors, composableErrs, otherErrs := classifyErrors(err) + + switch schemaType { + case "standard": + if standardErrors != nil { + var errWithSource spiceerrors.WithSourceError + if errors.As(standardErrors, &errWithSource) { + outputErrorWithSource(toPrint, validateContents, errWithSource) + shouldExit = true + } + return "", shouldExit, standardErrors + } + case "composable": + if composableErrs != nil { + var errWithSource spiceerrors.WithSourceError + if errors.As(composableErrs, &errWithSource) { + outputErrorWithSource(toPrint, validateContents, errWithSource) + shouldExit = true + } + return "", shouldExit, composableErrs + } + default: + // By default, validate will attempt to validate a schema first according to composable schema rules, + // then standard schema rules, + // and if both fail it will show the errors from composable schema. + if composableErrs != nil && standardErrors != nil { + var errWithSource spiceerrors.WithSourceError + if errors.As(composableErrs, &errWithSource) { + outputErrorWithSource(toPrint, validateContents, errWithSource) + shouldExit = true + } + return "", shouldExit, composableErrs + } + } + + if otherErrs != nil { + return "", false, otherErrs + } + + tuples := make([]*core.RelationTuple, 0) + totalAssertions := 0 + totalRelationsValidated := 0 + + for _, rel := range parsed.Relationships.Relationships { + tuples = append(tuples, rel.ToCoreTuple()) + } + + // Create the development context for each run + ctx := cmd.Context() + devCtx, devErrs, err := development.NewDevContext(ctx, &devinterface.RequestContext{ + Schema: parsed.Schema.Schema, + Relationships: tuples, + }) + if err != nil { + return "", false, err + } + if devErrs != nil { + schemaOffset := parsed.Schema.SourcePosition.LineNumber + if isOnlySchema { + schemaOffset = 0 + } + + // Output errors + outputDeveloperErrorsWithLineOffset(toPrint, validateContents, devErrs.InputErrors, schemaOffset) + return toPrint.String(), true, nil + } + // Run assertions + adevErrs, aerr := development.RunAllAssertions(devCtx, &parsed.Assertions) + if aerr != nil { + return "", false, aerr + } + if adevErrs != nil { + outputDeveloperErrors(toPrint, validateContents, adevErrs) + return toPrint.String(), true, nil + } + successfullyValidatedFiles++ + + // Run expected relations for file + _, erDevErrs, rerr := development.RunValidation(devCtx, &parsed.ExpectedRelations) + if rerr != nil { + return "", false, rerr + } + if erDevErrs != nil { + outputDeveloperErrors(toPrint, validateContents, erDevErrs) + return toPrint.String(), true, nil + } + // Print out any warnings for file + warnings, err := development.GetWarnings(ctx, devCtx) + if err != nil { + return "", false, err + } + + if len(warnings) > 0 { + for _, warning := range warnings { + fmt.Fprintf(toPrint, "%s%s\n", warningPrefix(), warning.Message) + outputForLine(toPrint, validateContents, uint64(warning.Line), warning.SourceCode, uint64(warning.Column)) // warning.LineNumber is 1-indexed + toPrint.WriteString("\n") + } + + toPrint.WriteString(complete()) + } else { + toPrint.WriteString(success()) + } + totalAssertions += len(parsed.Assertions.AssertTrue) + len(parsed.Assertions.AssertFalse) + totalRelationsValidated += len(parsed.ExpectedRelations.ValidationMap) + + fmt.Fprintf(toPrint, " - %d relationships loaded, %d assertions run, %d expected relations validated\n", + len(tuples), + totalAssertions, + totalRelationsValidated) + } + + if totalFiles > 1 { + fmt.Fprintf(toPrint, "total files: %d, successfully validated files: %d\n", totalFiles, successfullyValidatedFiles) + } + + return toPrint.String(), shouldExit, nil +} + +func outputErrorWithSource(sb *strings.Builder, validateContents []byte, errWithSource spiceerrors.WithSourceError) { + fmt.Fprintf(sb, "%s%s\n", errorPrefix(), errorMessageStyle().Render(errWithSource.Error())) + outputForLine(sb, validateContents, errWithSource.LineNumber, errWithSource.SourceCodeString, 0) // errWithSource.LineNumber is 1-indexed +} + +func outputForLine(sb *strings.Builder, validateContents []byte, oneIndexedLineNumber uint64, sourceCodeString string, oneIndexedColumnPosition uint64) { + lines := strings.Split(string(validateContents), "\n") + // These should be fine to be zero if the cast fails. + intLineNumber, _ := safecast.ToInt(oneIndexedLineNumber) + intColumnPosition, _ := safecast.ToInt(oneIndexedColumnPosition) + errorLineNumber := intLineNumber - 1 + for i := errorLineNumber - 3; i < errorLineNumber+3; i++ { + if i == errorLineNumber { + renderLine(sb, lines, i, sourceCodeString, errorLineNumber, intColumnPosition-1) + } else { + renderLine(sb, lines, i, "", errorLineNumber, -1) + } + } +} + +func outputDeveloperErrors(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError) { + outputDeveloperErrorsWithLineOffset(sb, validateContents, devErrors, 0) +} + +func outputDeveloperErrorsWithLineOffset(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError, lineOffset int) { + lines := strings.Split(string(validateContents), "\n") + + for _, devErr := range devErrors { + outputDeveloperError(sb, devErr, lines, lineOffset) + } +} + +func outputDeveloperError(sb *strings.Builder, devError *devinterface.DeveloperError, lines []string, lineOffset int) { + fmt.Fprintf(sb, "%s %s\n", errorPrefix(), errorMessageStyle().Render(devError.Message)) + errorLineNumber := int(devError.Line) - 1 + lineOffset // devError.Line is 1-indexed + for i := errorLineNumber - 3; i < errorLineNumber+3; i++ { + if i == errorLineNumber { + renderLine(sb, lines, i, devError.Context, errorLineNumber, -1) + } else { + renderLine(sb, lines, i, "", errorLineNumber, -1) + } + } + + if devError.CheckResolvedDebugInformation != nil && devError.CheckResolvedDebugInformation.Check != nil { + fmt.Fprintf(sb, "\n %s\n", traceStyle().Render("Explanation:")) + tp := printers.NewTreePrinter() + printers.DisplayCheckTrace(devError.CheckResolvedDebugInformation.Check, tp, true) + sb.WriteString(tp.Indented()) + } + + sb.WriteString("\n\n") +} + +func renderLine(sb *strings.Builder, lines []string, index int, highlight string, highlightLineIndex int, highlightStartingColumnIndex int) { + if index < 0 || index >= len(lines) { + return + } + + lineNumberLength := len(fmt.Sprintf("%d", len(lines))) + lineContents := strings.ReplaceAll(lines[index], "\t", " ") + lineDelimiter := "|" + + highlightLength := max(0, len(highlight)-1) + highlightColumnIndex := -1 + + // If the highlight string was provided, then we need to find the index of the highlight + // string in the line contents to determine where to place the caret. + if len(highlight) > 0 { + offset := 0 + for { + foundRelativeIndex := strings.Index(lineContents[offset:], highlight) + foundIndex := foundRelativeIndex + offset + if foundIndex >= highlightStartingColumnIndex { + highlightColumnIndex = foundIndex + break + } + + offset = foundIndex + 1 + if foundRelativeIndex < 0 || foundIndex > len(lineContents) { + break + } + } + } else if highlightStartingColumnIndex >= 0 { + // Otherwise, just show a caret at the specified starting column, if any. + highlightColumnIndex = highlightStartingColumnIndex + } + + lineNumberStr := fmt.Sprintf("%d", index+1) + noNumberSpaces := strings.Repeat(" ", lineNumberLength) + + lineNumberStyle := linePrefixStyle + lineContentsStyle := codeStyle + if index == highlightLineIndex { + lineNumberStyle = highlightedLineStyle + lineContentsStyle = highlightedCodeStyle + lineDelimiter = ">" + } + + lineNumberSpacer := strings.Repeat(" ", lineNumberLength-len(lineNumberStr)) + + if highlightColumnIndex < 0 { + fmt.Fprintf(sb, " %s%s %s %s\n", lineNumberSpacer, lineNumberStyle().Render(lineNumberStr), lineDelimiter, lineContentsStyle().Render(lineContents)) + } else { + fmt.Fprintf(sb, " %s%s %s %s%s%s\n", + lineNumberSpacer, + lineNumberStyle().Render(lineNumberStr), + lineDelimiter, + lineContentsStyle().Render(lineContents[0:highlightColumnIndex]), + highlightedSourceStyle().Render(highlight), + lineContentsStyle().Render(lineContents[highlightColumnIndex+len(highlight):])) + + fmt.Fprintf(sb, " %s %s %s%s%s\n", + noNumberSpaces, + lineDelimiter, + strings.Repeat(" ", highlightColumnIndex), + highlightedSourceStyle().Render("^"), + highlightedSourceStyle().Render(strings.Repeat("~", highlightLength))) + } +} + +// classifyErrors returns errors from the composable DSL, the standard DSL, and any other parsing errors. +func classifyErrors(err error) (error, error, error) { + if err == nil { + return nil, nil, nil + } + var standardErr compiler.BaseCompilerError + var composableErr composable.BaseCompilerError + var retStandard, retComposable, allOthers error + + ok := errors.As(err, &standardErr) + if ok { + retStandard = standardErr + } + ok = errors.As(err, &composableErr) + if ok { + retComposable = composableErr + } + + if retStandard == nil && retComposable == nil { + allOthers = err + } + + return retStandard, retComposable, allOthers +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/version.go b/vendor/github.com/authzed/zed/internal/cmd/version.go new file mode 100644 index 0000000..b87ec66 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/version.go @@ -0,0 +1,64 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/gookit/color" + "github.com/jzelinskie/cobrautil/v2" + "github.com/mattn/go-isatty" + "github.com/spf13/cobra" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/authzed/authzed-go/pkg/responsemeta" + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/console" +) + +func versionCmdFunc(cmd *cobra.Command, _ []string) error { + if !isatty.IsTerminal(os.Stdout.Fd()) { + color.Disable() + } + + includeRemoteVersion := cobrautil.MustGetBool(cmd, "include-remote-version") + hasContext := false + if includeRemoteVersion { + configStore, secretStore := client.DefaultStorage() + _, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) + hasContext = err == nil + } + + if hasContext && includeRemoteVersion { + green := color.FgGreen.Render + fmt.Print(green("client: ")) + } + + console.Println(cobrautil.UsageVersion("zed", cobrautil.MustGetBool(cmd, "include-deps"))) + + if hasContext && includeRemoteVersion { + client, err := client.NewClient(cmd) + if err != nil { + return err + } + + // NOTE: we ignore the error here, as it may be due to a schema not existing, or + // the client being unable to connect, etc. We just treat all such cases as an unknown + // version. + var headerMD metadata.MD + _, _ = client.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}, grpc.Header(&headerMD)) + version := headerMD.Get(string(responsemeta.ServerVersion)) + + blue := color.FgLightBlue.Render + fmt.Print(blue("service: ")) + if len(version) == 1 { + console.Println(version[0]) + } else { + console.Println("(unknown)") + } + } + + return nil +} diff --git a/vendor/github.com/authzed/zed/internal/commands/completion.go b/vendor/github.com/authzed/zed/internal/commands/completion.go new file mode 100644 index 0000000..dd24a74 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/commands/completion.go @@ -0,0 +1,165 @@ +package commands + +import ( + "errors" + "strings" + + "github.com/spf13/cobra" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + + "github.com/authzed/zed/internal/client" +) + +type CompletionArgumentType int + +const ( + ResourceType CompletionArgumentType = iota + ResourceID + Permission + SubjectType + SubjectID + SubjectTypeWithOptionalRelation +) + +func FileExtensionCompletions(extension ...string) func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { + return extension, cobra.ShellCompDirectiveFilterFileExt + } +} + +func GetArgs(fields ...CompletionArgumentType) func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + // Read the current schema, if any. + schema, err := readSchema(cmd) + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + // Find the specified resource type, if any. + var resourceType string + loop: + for index, arg := range args { + field := fields[index] + switch field { + case ResourceType: + resourceType = arg + break loop + + case ResourceID: + pieces := strings.Split(arg, ":") + if len(pieces) >= 1 { + resourceType = pieces[0] + break loop + } + } + } + + // Handle : on resource and subject IDs. + if strings.HasSuffix(toComplete, ":") && (fields[len(args)] == ResourceID || fields[len(args)] == SubjectID) { + comps := []string{} + comps = cobra.AppendActiveHelp(comps, "Please enter an object ID") + return comps, cobra.ShellCompDirectiveNoFileComp + } + + // Handle # on subject types. If the toComplete contains a valid subject, + // then we should return the relation names. Note that we cannot do this + // on the # character because shell autocompletion won't send it to us. + if len(args) == len(fields)-1 && toComplete != "" && fields[len(args)] == SubjectTypeWithOptionalRelation { + for _, objDef := range schema.ObjectDefinitions { + subjectType := toComplete + if objDef.Name == subjectType { + relationNames := make([]string, 0) + relationNames = append(relationNames, subjectType) + for _, relation := range objDef.Relation { + relationNames = append(relationNames, subjectType+"#"+relation.Name) + } + return relationNames, cobra.ShellCompDirectiveNoFileComp + } + } + } + + if len(args) >= len(fields) { + // If we have all the arguments, return no completions. + return nil, cobra.ShellCompDirectiveNoFileComp + } + + // Return the completions. + currentFieldType := fields[len(args)] + switch currentFieldType { + case ResourceType: + fallthrough + + case SubjectType: + fallthrough + + case SubjectID: + fallthrough + + case SubjectTypeWithOptionalRelation: + fallthrough + + case ResourceID: + resourceTypeNames := make([]string, 0, len(schema.ObjectDefinitions)) + for _, objDef := range schema.ObjectDefinitions { + resourceTypeNames = append(resourceTypeNames, objDef.Name) + } + + flags := cobra.ShellCompDirectiveNoFileComp + if currentFieldType == ResourceID || currentFieldType == SubjectID || currentFieldType == SubjectTypeWithOptionalRelation { + flags |= cobra.ShellCompDirectiveNoSpace + } + + return resourceTypeNames, flags + + case Permission: + if resourceType == "" { + return nil, cobra.ShellCompDirectiveNoFileComp + } + + relationNames := make([]string, 0) + for _, objDef := range schema.ObjectDefinitions { + if objDef.Name == resourceType { + for _, relation := range objDef.Relation { + relationNames = append(relationNames, relation.Name) + } + } + } + return relationNames, cobra.ShellCompDirectiveNoFileComp + } + + return nil, cobra.ShellCompDirectiveDefault + } +} + +func readSchema(cmd *cobra.Command) (*compiler.CompiledSchema, error) { + // TODO: we should find a way to cache this + client, err := client.NewClient(cmd) + if err != nil { + return nil, err + } + + request := &v1.ReadSchemaRequest{} + + resp, err := client.ReadSchema(cmd.Context(), request) + if err != nil { + return nil, err + } + + schemaText := resp.SchemaText + if len(schemaText) == 0 { + return nil, errors.New("no schema defined") + } + + compiledSchema, err := compiler.Compile( + compiler.InputSchema{Source: "schema", SchemaString: schemaText}, + compiler.AllowUnprefixedObjectType(), + compiler.SkipValidation(), + ) + if err != nil { + return nil, err + } + + return compiledSchema, nil +} diff --git a/vendor/github.com/authzed/zed/internal/commands/permission.go b/vendor/github.com/authzed/zed/internal/commands/permission.go new file mode 100644 index 0000000..58b022c --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/commands/permission.go @@ -0,0 +1,673 @@ +package commands + +import ( + "errors" + "fmt" + "io" + "os" + "strings" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/stringz" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/encoding/prototext" + + "github.com/authzed/authzed-go/pkg/requestmeta" + "github.com/authzed/authzed-go/pkg/responsemeta" + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/tuple" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/console" + "github.com/authzed/zed/internal/printers" +) + +var ErrMultipleConsistencies = errors.New("provided multiple consistency flags") + +func registerConsistencyFlags(flags *pflag.FlagSet) { + flags.String("consistency-at-exactly", "", "evaluate at the provided zedtoken") + flags.String("consistency-at-least", "", "evaluate at least as consistent as the provided zedtoken") + flags.Bool("consistency-min-latency", false, "evaluate at the zedtoken preferred by the database") + flags.Bool("consistency-full", false, "evaluate at the newest zedtoken in the database") +} + +func consistencyFromCmd(cmd *cobra.Command) (c *v1.Consistency, err error) { + if cobrautil.MustGetBool(cmd, "consistency-full") { + c = &v1.Consistency{Requirement: &v1.Consistency_FullyConsistent{FullyConsistent: true}} + } + if atLeast := cobrautil.MustGetStringExpanded(cmd, "consistency-at-least"); atLeast != "" { + if c != nil { + return nil, ErrMultipleConsistencies + } + c = &v1.Consistency{Requirement: &v1.Consistency_AtLeastAsFresh{AtLeastAsFresh: &v1.ZedToken{Token: atLeast}}} + } + + // Deprecated (hidden) flag. + if revision := cobrautil.MustGetStringExpanded(cmd, "revision"); revision != "" { + if c != nil { + return nil, ErrMultipleConsistencies + } + c = &v1.Consistency{Requirement: &v1.Consistency_AtLeastAsFresh{AtLeastAsFresh: &v1.ZedToken{Token: revision}}} + } + + if exact := cobrautil.MustGetStringExpanded(cmd, "consistency-at-exactly"); exact != "" { + if c != nil { + return nil, ErrMultipleConsistencies + } + c = &v1.Consistency{Requirement: &v1.Consistency_AtExactSnapshot{AtExactSnapshot: &v1.ZedToken{Token: exact}}} + } + + if c == nil { + c = &v1.Consistency{Requirement: &v1.Consistency_MinimizeLatency{MinimizeLatency: true}} + } + return +} + +func RegisterPermissionCmd(rootCmd *cobra.Command) *cobra.Command { + rootCmd.AddCommand(permissionCmd) + + permissionCmd.AddCommand(checkCmd) + checkCmd.Flags().Bool("json", false, "output as JSON") + checkCmd.Flags().String("revision", "", "optional revision at which to check") + _ = checkCmd.Flags().MarkHidden("revision") + checkCmd.Flags().Bool("explain", false, "requests debug information from SpiceDB and prints out a trace of the requests") + checkCmd.Flags().Bool("schema", false, "requests debug information from SpiceDB and prints out the schema used") + checkCmd.Flags().Bool("error-on-no-permission", false, "if true, zed will return exit code 1 if subject does not have unconditional permission") + checkCmd.Flags().String("caveat-context", "", "the caveat context to send along with the check, in JSON form") + registerConsistencyFlags(checkCmd.Flags()) + + permissionCmd.AddCommand(checkBulkCmd) + checkBulkCmd.Flags().String("revision", "", "optional revision at which to check") + checkBulkCmd.Flags().Bool("json", false, "output as JSON") + checkBulkCmd.Flags().Bool("explain", false, "requests debug information from SpiceDB and prints out a trace of the requests") + checkBulkCmd.Flags().Bool("schema", false, "requests debug information from SpiceDB and prints out the schema used") + registerConsistencyFlags(checkBulkCmd.Flags()) + + permissionCmd.AddCommand(expandCmd) + expandCmd.Flags().Bool("json", false, "output as JSON") + expandCmd.Flags().String("revision", "", "optional revision at which to check") + registerConsistencyFlags(expandCmd.Flags()) + + // NOTE: `lookup` is an alias of `lookup-resources` (below) + // and must have the same list of flags in order for it to work. + permissionCmd.AddCommand(lookupCmd) + lookupCmd.Flags().Bool("json", false, "output as JSON") + lookupCmd.Flags().String("revision", "", "optional revision at which to check") + lookupCmd.Flags().String("caveat-context", "", "the caveat context to send along with the lookup, in JSON form") + lookupCmd.Flags().Uint32("page-limit", 0, "limit of relations returned per page") + registerConsistencyFlags(lookupCmd.Flags()) + + permissionCmd.AddCommand(lookupResourcesCmd) + lookupResourcesCmd.Flags().Bool("json", false, "output as JSON") + lookupResourcesCmd.Flags().String("revision", "", "optional revision at which to check") + lookupResourcesCmd.Flags().String("caveat-context", "", "the caveat context to send along with the lookup, in JSON form") + lookupResourcesCmd.Flags().Uint32("page-limit", 0, "limit of relations returned per page") + lookupResourcesCmd.Flags().String("cursor", "", "resume pagination from a specific cursor token") + lookupResourcesCmd.Flags().Bool("show-cursor", true, "display the cursor token after pagination") + registerConsistencyFlags(lookupResourcesCmd.Flags()) + + permissionCmd.AddCommand(lookupSubjectsCmd) + lookupSubjectsCmd.Flags().Bool("json", false, "output as JSON") + lookupSubjectsCmd.Flags().String("revision", "", "optional revision at which to check") + lookupSubjectsCmd.Flags().String("caveat-context", "", "the caveat context to send along with the lookup, in JSON form") + registerConsistencyFlags(lookupSubjectsCmd.Flags()) + + return permissionCmd +} + +var permissionCmd = &cobra.Command{ + Use: "permission <subcommand>", + Short: "Query the permissions in a permissions system", + Aliases: []string{"perm"}, +} + +var checkBulkCmd = &cobra.Command{ + Use: "bulk <resource:id#permission@subject:id> <resource:id#permission@subject:id> ...", + Short: "Check a permissions in bulk exists for a resource-subject pairs", + Args: ValidationWrapper(cobra.MinimumNArgs(1)), + RunE: checkBulkCmdFunc, +} + +var checkCmd = &cobra.Command{ + Use: "check <resource:id> <permission> <subject:id>", + Short: "Check that a permission exists for a subject", + Args: ValidationWrapper(cobra.ExactArgs(3)), + ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectID), + RunE: checkCmdFunc, +} + +var expandCmd = &cobra.Command{ + Use: "expand <permission> <resource:id>", + Short: "Expand the structure of a permission", + Args: ValidationWrapper(cobra.ExactArgs(2)), + ValidArgsFunction: cobra.NoFileCompletions, + RunE: expandCmdFunc, +} + +var lookupResourcesCmd = &cobra.Command{ + Use: "lookup-resources <type> <permission> <subject:id>", + Short: "Enumerates resources of a given type for which the subject has permission", + Args: ValidationWrapper(cobra.ExactArgs(3)), + ValidArgsFunction: GetArgs(ResourceType, Permission, SubjectID), + RunE: lookupResourcesCmdFunc, +} + +var lookupCmd = &cobra.Command{ + Use: "lookup <type> <permission> <subject:id>", + Short: "Enumerates the resources of a given type for which the subject has permission", + Args: ValidationWrapper(cobra.ExactArgs(3)), + ValidArgsFunction: GetArgs(ResourceType, Permission, SubjectID), + RunE: lookupResourcesCmdFunc, + Deprecated: "prefer lookup-resources", + Hidden: true, +} + +var lookupSubjectsCmd = &cobra.Command{ + Use: "lookup-subjects <resource:id> <permission> <subject_type#optional_subject_relation>", + Short: "Enumerates the subjects of a given type for which the subject has permission on the resource", + Args: ValidationWrapper(cobra.ExactArgs(3)), + ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), + RunE: lookupSubjectsCmdFunc, +} + +func checkCmdFunc(cmd *cobra.Command, args []string) error { + var objectNS, objectID string + err := stringz.SplitExact(args[0], ":", &objectNS, &objectID) + if err != nil { + return err + } + + relation := args[1] + + subjectNS, subjectID, subjectRel, err := ParseSubject(args[2]) + if err != nil { + return err + } + + caveatContext, err := GetCaveatContext(cmd) + if err != nil { + return err + } + + consistency, err := consistencyFromCmd(cmd) + if err != nil { + return err + } + + client, err := client.NewClient(cmd) + if err != nil { + return err + } + + request := &v1.CheckPermissionRequest{ + Resource: &v1.ObjectReference{ + ObjectType: objectNS, + ObjectId: objectID, + }, + Permission: relation, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: subjectNS, + ObjectId: subjectID, + }, + OptionalRelation: subjectRel, + }, + Context: caveatContext, + Consistency: consistency, + } + log.Trace().Interface("request", request).Send() + + ctx := cmd.Context() + if cobrautil.MustGetBool(cmd, "explain") || cobrautil.MustGetBool(cmd, "schema") { + log.Info().Msg("debugging requested on check") + ctx = requestmeta.AddRequestHeaders(ctx, requestmeta.RequestDebugInformation) + request.WithTracing = true + } + + var trailerMD metadata.MD + resp, err := client.CheckPermission(ctx, request, grpc.Trailer(&trailerMD)) + if err != nil { + var debugInfo *v1.DebugInformation + + // Check for the debug trace contained in the error details. + if errInfo, ok := grpcErrorInfoFrom(err); ok { + if encodedDebugInfo, ok := errInfo.Metadata["debug_trace_proto_text"]; ok { + debugInfo = &v1.DebugInformation{} + if uerr := prototext.Unmarshal([]byte(encodedDebugInfo), debugInfo); uerr != nil { + return uerr + } + } + } + + derr := displayDebugInformationIfRequested(cmd, debugInfo, trailerMD, true) + if derr != nil { + return derr + } + + return err + } + + if cobrautil.MustGetBool(cmd, "json") { + prettyProto, err := PrettyProto(resp) + if err != nil { + return err + } + + console.Println(string(prettyProto)) + return nil + } + + switch resp.Permissionship { + case v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION: + log.Warn().Strs("fields", resp.PartialCaveatInfo.MissingRequiredContext).Msg("missing fields in caveat context") + console.Println("caveated") + + case v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION: + console.Println("true") + + case v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION: + console.Println("false") + + default: + return fmt.Errorf("unknown permission response: %v", resp.Permissionship) + } + + err = displayDebugInformationIfRequested(cmd, resp.DebugTrace, trailerMD, false) + if err != nil { + return err + } + + if cobrautil.MustGetBool(cmd, "error-on-no-permission") { + if resp.Permissionship != v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION { + os.Exit(1) + } + } + + return nil +} + +func checkBulkCmdFunc(cmd *cobra.Command, args []string) error { + items := make([]*v1.CheckBulkPermissionsRequestItem, 0, len(args)) + for _, arg := range args { + rel, err := tuple.ParseV1Rel(arg) + if err != nil { + return fmt.Errorf("unable to parse relation: %s", arg) + } + + item := &v1.CheckBulkPermissionsRequestItem{ + Resource: &v1.ObjectReference{ + ObjectType: rel.Resource.ObjectType, + ObjectId: rel.Resource.ObjectId, + }, + Permission: rel.Relation, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: rel.Subject.Object.ObjectType, + ObjectId: rel.Subject.Object.ObjectId, + }, + }, + } + if rel.OptionalCaveat != nil { + item.Context = rel.OptionalCaveat.Context + } + items = append(items, item) + } + + consistency, err := consistencyFromCmd(cmd) + if err != nil { + return err + } + + bulk := &v1.CheckBulkPermissionsRequest{ + Consistency: consistency, + Items: items, + } + + log.Trace().Interface("request", bulk).Send() + + ctx := cmd.Context() + c, err := client.NewClient(cmd) + if err != nil { + return err + } + + if cobrautil.MustGetBool(cmd, "explain") || cobrautil.MustGetBool(cmd, "schema") { + bulk.WithTracing = true + } + + resp, err := c.CheckBulkPermissions(ctx, bulk) + if err != nil { + return err + } + + if cobrautil.MustGetBool(cmd, "json") { + prettyProto, err := PrettyProto(resp) + if err != nil { + return err + } + + console.Println(string(prettyProto)) + return nil + } + + for _, item := range resp.Pairs { + console.Printf("%s:%s#%s@%s:%s => ", + item.Request.Resource.ObjectType, item.Request.Resource.ObjectId, item.Request.Permission, item.Request.Subject.Object.ObjectType, item.Request.Subject.Object.ObjectId) + + switch responseType := item.Response.(type) { + case *v1.CheckBulkPermissionsPair_Item: + switch responseType.Item.Permissionship { + case v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION: + console.Println("caveated") + + case v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION: + console.Println("true") + + case v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION: + console.Println("false") + } + + err = displayDebugInformationIfRequested(cmd, responseType.Item.DebugTrace, nil, false) + if err != nil { + return err + } + + case *v1.CheckBulkPermissionsPair_Error: + console.Println(fmt.Sprintf("error: %s", responseType.Error)) + } + } + + return nil +} + +func expandCmdFunc(cmd *cobra.Command, args []string) error { + relation := args[0] + + var objectNS, objectID string + err := stringz.SplitExact(args[1], ":", &objectNS, &objectID) + if err != nil { + return err + } + + consistency, err := consistencyFromCmd(cmd) + if err != nil { + return err + } + + client, err := client.NewClient(cmd) + if err != nil { + return err + } + + request := &v1.ExpandPermissionTreeRequest{ + Resource: &v1.ObjectReference{ + ObjectType: objectNS, + ObjectId: objectID, + }, + Permission: relation, + Consistency: consistency, + } + log.Trace().Interface("request", request).Send() + + resp, err := client.ExpandPermissionTree(cmd.Context(), request) + if err != nil { + return err + } + + if cobrautil.MustGetBool(cmd, "json") { + prettyProto, err := PrettyProto(resp) + if err != nil { + return err + } + + console.Println(string(prettyProto)) + return nil + } + + tp := printers.NewTreePrinter() + printers.TreeNodeTree(tp, resp.TreeRoot) + tp.Print() + + return nil +} + +var newLookupResourcesPageCallbackForTests func(readByPage uint) + +func lookupResourcesCmdFunc(cmd *cobra.Command, args []string) error { + objectNS := args[0] + relation := args[1] + subjectNS, subjectID, subjectRel, err := ParseSubject(args[2]) + if err != nil { + return err + } + + pageLimit := cobrautil.MustGetUint32(cmd, "page-limit") + caveatContext, err := GetCaveatContext(cmd) + if err != nil { + return err + } + + consistency, err := consistencyFromCmd(cmd) + if err != nil { + return err + } + + client, err := client.NewClient(cmd) + if err != nil { + return err + } + + var cursor *v1.Cursor + if cursorStr := cobrautil.MustGetString(cmd, "cursor"); cursorStr != "" { + cursor = &v1.Cursor{Token: cursorStr} + } + + var totalCount uint + for { + request := &v1.LookupResourcesRequest{ + ResourceObjectType: objectNS, + Permission: relation, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: subjectNS, + ObjectId: subjectID, + }, + OptionalRelation: subjectRel, + }, + Context: caveatContext, + Consistency: consistency, + OptionalLimit: pageLimit, + OptionalCursor: cursor, + } + log.Trace().Interface("request", request).Uint32("page-limit", pageLimit).Send() + + respStream, err := client.LookupResources(cmd.Context(), request) + if err != nil { + return err + } + + var count uint + + stream: + for { + resp, err := respStream.Recv() + switch { + case errors.Is(err, io.EOF): + break stream + case err != nil: + return err + default: + count++ + totalCount++ + if cobrautil.MustGetBool(cmd, "json") { + prettyProto, err := PrettyProto(resp) + if err != nil { + return err + } + + console.Println(string(prettyProto)) + } + + console.Println(prettyLookupPermissionship(resp.ResourceObjectId, resp.Permissionship, resp.PartialCaveatInfo)) + cursor = resp.AfterResultCursor + } + } + + if newLookupResourcesPageCallbackForTests != nil { + newLookupResourcesPageCallbackForTests(count) + } + if count == 0 || pageLimit == 0 || count < uint(pageLimit) { + log.Trace().Interface("request", request).Uint32("page-limit", pageLimit).Uint("count", totalCount).Send() + break + } + } + + showCursor := cobrautil.MustGetBool(cmd, "show-cursor") + if showCursor && cursor != nil { + console.Printf("Last cursor: %s\n", cursor.Token) + } + + return nil +} + +func lookupSubjectsCmdFunc(cmd *cobra.Command, args []string) error { + var objectNS, objectID string + err := stringz.SplitExact(args[0], ":", &objectNS, &objectID) + if err != nil { + return err + } + + permission := args[1] + + subjectType, subjectRelation := ParseType(args[2]) + + caveatContext, err := GetCaveatContext(cmd) + if err != nil { + return err + } + + consistency, err := consistencyFromCmd(cmd) + if err != nil { + return err + } + + client, err := client.NewClient(cmd) + if err != nil { + return err + } + request := &v1.LookupSubjectsRequest{ + Resource: &v1.ObjectReference{ + ObjectType: objectNS, + ObjectId: objectID, + }, + Permission: permission, + SubjectObjectType: subjectType, + OptionalSubjectRelation: subjectRelation, + Context: caveatContext, + Consistency: consistency, + } + log.Trace().Interface("request", request).Send() + + respStream, err := client.LookupSubjects(cmd.Context(), request) + if err != nil { + return err + } + + for { + resp, err := respStream.Recv() + switch { + case errors.Is(err, io.EOF): + return nil + case err != nil: + return err + default: + if cobrautil.MustGetBool(cmd, "json") { + prettyProto, err := PrettyProto(resp) + if err != nil { + return err + } + + console.Println(string(prettyProto)) + } + console.Printf("%s:%s%s\n", + subjectType, + prettyLookupPermissionship(resp.Subject.SubjectObjectId, resp.Subject.Permissionship, resp.Subject.PartialCaveatInfo), + excludedSubjectsString(resp.ExcludedSubjects), + ) + } + } +} + +func excludedSubjectsString(excluded []*v1.ResolvedSubject) string { + if len(excluded) == 0 { + return "" + } + + var b strings.Builder + fmt.Fprintf(&b, " - {\n") + for _, subj := range excluded { + fmt.Fprintf(&b, "\t%s\n", prettyLookupPermissionship( + subj.SubjectObjectId, + subj.Permissionship, + subj.PartialCaveatInfo, + )) + } + fmt.Fprintf(&b, "}") + return b.String() +} + +func prettyLookupPermissionship(objectID string, p v1.LookupPermissionship, info *v1.PartialCaveatInfo) string { + var b strings.Builder + fmt.Fprint(&b, objectID) + if p == v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_CONDITIONAL_PERMISSION { + fmt.Fprintf(&b, " (caveated, missing context: %s)", strings.Join(info.MissingRequiredContext, ", ")) + } + return b.String() +} + +func displayDebugInformationIfRequested(cmd *cobra.Command, debug *v1.DebugInformation, trailerMD metadata.MD, hasError bool) error { + if cobrautil.MustGetBool(cmd, "explain") || cobrautil.MustGetBool(cmd, "schema") { + debugInfo := &v1.DebugInformation{} + // DebugInformation comes in trailer < 1.30, and in response payload >= 1.30 + if debug == nil { + found, err := responsemeta.GetResponseTrailerMetadataOrNil(trailerMD, responsemeta.DebugInformation) + if err != nil { + return err + } + + if found == nil { + log.Warn().Msg("No debugging information returned for the check") + return nil + } + + err = protojson.Unmarshal([]byte(*found), debugInfo) + if err != nil { + return err + } + } else { + debugInfo = debug + } + + if debugInfo.Check == nil { + log.Warn().Msg("No trace found for the check") + return nil + } + + if cobrautil.MustGetBool(cmd, "explain") { + tp := printers.NewTreePrinter() + printers.DisplayCheckTrace(debugInfo.Check, tp, hasError) + tp.Print() + } + + if cobrautil.MustGetBool(cmd, "schema") { + console.Println() + console.Println(debugInfo.SchemaUsed) + } + } + return nil +} diff --git a/vendor/github.com/authzed/zed/internal/commands/relationship.go b/vendor/github.com/authzed/zed/internal/commands/relationship.go new file mode 100644 index 0000000..0471f85 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/commands/relationship.go @@ -0,0 +1,561 @@ +package commands + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "os" + "strings" + "time" + "unicode" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/stringz" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/tuple" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/console" +) + +func RegisterRelationshipCmd(rootCmd *cobra.Command) *cobra.Command { + rootCmd.AddCommand(relationshipCmd) + + relationshipCmd.AddCommand(createCmd) + createCmd.Flags().Bool("json", false, "output as JSON") + createCmd.Flags().String("caveat", "", `the caveat for the relationship, with format: 'caveat_name:{"some":"context"}'`) + createCmd.Flags().String("expiration-time", "", `the expiration time of the relationship in RFC 3339 format`) + createCmd.Flags().IntP("batch-size", "b", 100, "batch size when writing streams of relationships from stdin") + + relationshipCmd.AddCommand(touchCmd) + touchCmd.Flags().Bool("json", false, "output as JSON") + touchCmd.Flags().String("caveat", "", `the caveat for the relationship, with format: 'caveat_name:{"some":"context"}'`) + touchCmd.Flags().String("expiration-time", "", `the expiration time for the relationship in RFC 3339 format`) + touchCmd.Flags().IntP("batch-size", "b", 100, "batch size when writing streams of relationships from stdin") + + relationshipCmd.AddCommand(deleteCmd) + deleteCmd.Flags().Bool("json", false, "output as JSON") + deleteCmd.Flags().IntP("batch-size", "b", 100, "batch size when deleting streams of relationships from stdin") + + relationshipCmd.AddCommand(readCmd) + readCmd.Flags().Bool("json", false, "output as JSON") + readCmd.Flags().String("revision", "", "optional revision at which to check") + _ = readCmd.Flags().MarkHidden("revision") + readCmd.Flags().String("subject-filter", "", "optional subject filter") + readCmd.Flags().Uint32("page-limit", 100, "limit of relations returned per page") + registerConsistencyFlags(readCmd.Flags()) + + relationshipCmd.AddCommand(bulkDeleteCmd) + bulkDeleteCmd.Flags().Bool("force", false, "force deletion of all elements in batches defined by <optional-limit>") + bulkDeleteCmd.Flags().String("subject-filter", "", "optional subject filter") + bulkDeleteCmd.Flags().Uint32("optional-limit", 1000, "the max amount of elements to delete. If you want to delete all in batches of size <optional-limit>, set --force to true") + bulkDeleteCmd.Flags().Bool("estimate-count", true, "estimate the count of relationships to be deleted") + _ = bulkDeleteCmd.Flags().MarkDeprecated("estimate-count", "no longer used, make use of --optional-limit instead") + return relationshipCmd +} + +var relationshipCmd = &cobra.Command{ + Use: "relationship <subcommand>", + Short: "Query and mutate the relationships in a permissions system", +} + +var createCmd = &cobra.Command{ + Use: "create <resource:id> <relation> <subject:id#optional_subject_relation>", + Short: "Create a relationship for a subject", + Args: ValidationWrapper(StdinOrExactArgs(3)), + ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), + RunE: writeRelationshipCmdFunc(v1.RelationshipUpdate_OPERATION_CREATE, os.Stdin), +} + +var touchCmd = &cobra.Command{ + Use: "touch <resource:id> <relation> <subject:id#optional_subject_relation>", + Short: "Idempotently updates a relationship for a subject", + Args: ValidationWrapper(StdinOrExactArgs(3)), + ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), + RunE: writeRelationshipCmdFunc(v1.RelationshipUpdate_OPERATION_TOUCH, os.Stdin), +} + +var deleteCmd = &cobra.Command{ + Use: "delete <resource:id> <relation> <subject:id#optional_subject_relation>", + Short: "Deletes a relationship", + Args: ValidationWrapper(StdinOrExactArgs(3)), + ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), + RunE: writeRelationshipCmdFunc(v1.RelationshipUpdate_OPERATION_DELETE, os.Stdin), +} + +const readCmdHelpLong = `Enumerates relationships matching the provided pattern. + +To filter returned relationships using a resource ID prefix, append a '%' to the resource ID: + +zed relationship read some-type:some-prefix-% +` + +var readCmd = &cobra.Command{ + Use: "read <resource_type:optional_resource_id> <optional_relation> <optional_subject_type:optional_subject_id#optional_subject_relation>", + Short: "Enumerates relationships matching the provided pattern", + Long: readCmdHelpLong, + Args: ValidationWrapper(cobra.RangeArgs(1, 3)), + ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), + RunE: readRelationships, +} + +var bulkDeleteCmd = &cobra.Command{ + Use: "bulk-delete <resource_type:optional_resource_id> <optional_relation> <optional_subject_type:optional_subject_id#optional_subject_relation>", + Short: "Deletes relationships matching the provided pattern en masse", + Args: ValidationWrapper(cobra.RangeArgs(1, 3)), + ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), + RunE: bulkDeleteRelationships, +} + +func StdinOrExactArgs(n int) cobra.PositionalArgs { + return func(cmd *cobra.Command, args []string) error { + if ok := isArgsViaFile(os.Stdin) && len(args) == 0; ok { + return nil + } + + return cobra.ExactArgs(n)(cmd, args) + } +} + +func isArgsViaFile(file *os.File) bool { + return !isFileTerminal(file) +} + +func bulkDeleteRelationships(cmd *cobra.Command, args []string) error { + spicedbClient, err := client.NewClient(cmd) + if err != nil { + return err + } + + filter, err := buildRelationshipsFilter(cmd, args) + if err != nil { + return err + } + + bar := console.CreateProgressBar("deleting relationships") + defer func() { + _ = bar.Finish() + }() + + allowPartialDeletions := cobrautil.MustGetBool(cmd, "force") + optionalLimit := cobrautil.MustGetUint32(cmd, "optional-limit") + + var resp *v1.DeleteRelationshipsResponse + for { + delRequest := &v1.DeleteRelationshipsRequest{ + RelationshipFilter: filter, + OptionalLimit: optionalLimit, + OptionalAllowPartialDeletions: allowPartialDeletions, + } + log.Trace().Interface("request", delRequest).Msg("deleting relationships") + + resp, err = spicedbClient.DeleteRelationships(cmd.Context(), delRequest) + if errorInfo, ok := grpcErrorInfoFrom(err); ok { + if errorInfo.GetReason() == v1.ErrorReason_ERROR_REASON_TOO_MANY_RELATIONSHIPS_FOR_TRANSACTIONAL_DELETE.String() { + resourceType := "relationships" + if returnedResourceType, ok := errorInfo.GetMetadata()["filter_resource_type"]; ok { + resourceType = returnedResourceType + } + + return fmt.Errorf("could not delete %s, as more than %s relationships were found. Consider increasing --optional-limit or deleting all relationships using --force", + resourceType, + errorInfo.GetMetadata()["limit"]) + } + } + if err != nil { + return err + } + + if resp.DeletionProgress == v1.DeleteRelationshipsResponse_DELETION_PROGRESS_COMPLETE { + break + } + + if err := bar.Add(int(optionalLimit)); err != nil { + return err + } + } + + _ = bar.Finish() + console.Println(resp.DeletedAt.GetToken()) + return nil +} + +func grpcErrorInfoFrom(err error) (*errdetails.ErrorInfo, bool) { + if err == nil { + return nil, false + } + + if s, ok := status.FromError(err); ok { + for _, d := range s.Details() { + if errInfo, ok := d.(*errdetails.ErrorInfo); ok { + return errInfo, true + } + } + } + + return nil, false +} + +func buildRelationshipsFilter(cmd *cobra.Command, args []string) (*v1.RelationshipFilter, error) { + filter := &v1.RelationshipFilter{ResourceType: args[0]} + + if strings.Contains(args[0], ":") { + var resourceID string + err := stringz.SplitExact(args[0], ":", &filter.ResourceType, &resourceID) + if err != nil { + return nil, err + } + + if strings.HasSuffix(resourceID, "%") { + filter.OptionalResourceIdPrefix = strings.TrimSuffix(resourceID, "%") + } else { + filter.OptionalResourceId = resourceID + } + } + + if len(args) > 1 { + filter.OptionalRelation = args[1] + } + + subjectFilter := cobrautil.MustGetString(cmd, "subject-filter") + if len(args) == 3 { + if subjectFilter != "" { + return nil, errors.New("cannot specify subject filter both positionally and via --subject-filter") + } + subjectFilter = args[2] + } + + if subjectFilter != "" { + if strings.Contains(subjectFilter, ":") { + subjectNS, subjectID, subjectRel, err := ParseSubject(subjectFilter) + if err != nil { + return nil, err + } + + filter.OptionalSubjectFilter = &v1.SubjectFilter{ + SubjectType: subjectNS, + OptionalSubjectId: subjectID, + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: subjectRel, + }, + } + } else { + filter.OptionalSubjectFilter = &v1.SubjectFilter{ + SubjectType: subjectFilter, + } + } + } + + return filter, nil +} + +func readRelationships(cmd *cobra.Command, args []string) error { + spicedbClient, err := client.NewClient(cmd) + if err != nil { + return err + } + + filter, err := buildRelationshipsFilter(cmd, args) + if err != nil { + return err + } + + request := &v1.ReadRelationshipsRequest{RelationshipFilter: filter} + + limit := cobrautil.MustGetUint32(cmd, "page-limit") + request.OptionalLimit = limit + request.Consistency, err = consistencyFromCmd(cmd) + if err != nil { + return err + } + + lastCursor := request.OptionalCursor + for { + request.OptionalCursor = lastCursor + var cursorToken string + if lastCursor != nil { + cursorToken = lastCursor.Token + } + log.Trace().Interface("request", request).Str("cursor", cursorToken).Msg("reading relationships page") + readRelClient, err := spicedbClient.ReadRelationships(cmd.Context(), request) + if err != nil { + return err + } + + var relCount uint32 + for { + if err := cmd.Context().Err(); err != nil { + return err + } + + msg, err := readRelClient.Recv() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return err + } + + lastCursor = msg.AfterResultCursor + relCount++ + if err := printRelationship(cmd, msg); err != nil { + return err + } + } + + if relCount < limit || limit == 0 { + return nil + } + + if relCount > limit { + log.Warn().Uint32("limit-specified", limit).Uint32("relationships-received", relCount).Msg("page limit ignored, pagination may not be supported by the server, consider updating SpiceDB") + return nil + } + } +} + +func printRelationship(cmd *cobra.Command, msg *v1.ReadRelationshipsResponse) error { + if cobrautil.MustGetBool(cmd, "json") { + prettyProto, err := PrettyProto(msg) + if err != nil { + return err + } + + console.Println(string(prettyProto)) + } else { + relString, err := relationshipToString(msg.Relationship) + if err != nil { + return err + } + console.Println(relString) + } + + return nil +} + +func argsToRelationship(args []string) (*v1.Relationship, error) { + if len(args) != 3 { + return nil, fmt.Errorf("expected 3 arguments, but got %d", len(args)) + } + + rel, err := tupleToRel(args[0], args[1], args[2]) + if err != nil { + return nil, errors.New("failed to parse input arguments") + } + + return rel, nil +} + +func relationshipToString(rel *v1.Relationship) (string, error) { + relString, err := tuple.V1StringRelationship(rel) + if err != nil { + return "", err + } + + relString = strings.Replace(relString, "@", " ", 1) + relString = strings.Replace(relString, "#", " ", 1) + return relString, nil +} + +// parseRelationshipLine splits a line of update input that comes from stdin +// and returns the fields representing the 3 arguments. This is to handle +// the fact that relationships specified via stdin can't escape spaces like +// shell arguments. +func parseRelationshipLine(line string) (string, string, string, error) { + line = strings.TrimSpace(line) + resourceIdx := strings.IndexFunc(line, unicode.IsSpace) + if resourceIdx == -1 { + args := 0 + if line != "" { + args = 1 + } + return "", "", "", fmt.Errorf("expected %s to have 3 arguments, but got %v", line, args) + } + + resource := line[:resourceIdx] + rest := strings.TrimSpace(line[resourceIdx+1:]) + relationIdx := strings.IndexFunc(rest, unicode.IsSpace) + if relationIdx == -1 { + args := 1 + if strings.TrimSpace(rest) != "" { + args = 2 + } + return "", "", "", fmt.Errorf("expected %s to have 3 arguments, but got %v", line, args) + } + + relation := rest[:relationIdx] + rest = strings.TrimSpace(rest[relationIdx+1:]) + if rest == "" { + return "", "", "", fmt.Errorf("expected %s to have 3 arguments, but got 2", line) + } + + return resource, relation, rest, nil +} + +func FileRelationshipParser(f *os.File) RelationshipParser { + scanner := bufio.NewScanner(f) + return func() (*v1.Relationship, error) { + if scanner.Scan() { + res, rel, subj, err := parseRelationshipLine(scanner.Text()) + if err != nil { + return nil, err + } + return tupleToRel(res, rel, subj) + } + if err := scanner.Err(); err != nil { + return nil, err + } + return nil, ErrExhaustedRelationships + } +} + +func tupleToRel(resource, relation, subject string) (*v1.Relationship, error) { + return tuple.ParseV1Rel(resource + "#" + relation + "@" + subject) +} + +func SliceRelationshipParser(args []string) RelationshipParser { + ran := false + return func() (*v1.Relationship, error) { + if ran { + return nil, ErrExhaustedRelationships + } + ran = true + return tupleToRel(args[0], args[1], args[2]) + } +} + +func writeUpdates(ctx context.Context, spicedbClient client.Client, updates []*v1.RelationshipUpdate, json bool) error { + if len(updates) == 0 { + return nil + } + request := &v1.WriteRelationshipsRequest{ + Updates: updates, + OptionalPreconditions: nil, + } + + log.Trace().Interface("request", request).Msg("writing relationships") + resp, err := spicedbClient.WriteRelationships(ctx, request) + if err != nil { + return err + } + + if json { + prettyProto, err := PrettyProto(resp) + if err != nil { + return err + } + + console.Println(string(prettyProto)) + } else { + console.Println(resp.WrittenAt.GetToken()) + } + + return nil +} + +// RelationshipParser is a closure that can produce relationships. +// When there are no more relationships, it will return ErrExhaustedRelationships. +type RelationshipParser func() (*v1.Relationship, error) + +// ErrExhaustedRelationships signals that the last producible value of a RelationshipParser +// has already been consumed. +// Functions should return this error to signal a graceful end of input. +var ErrExhaustedRelationships = errors.New("exhausted all relationships") + +func writeRelationshipCmdFunc(operation v1.RelationshipUpdate_Operation, input *os.File) func(cmd *cobra.Command, args []string) error { + return func(cmd *cobra.Command, args []string) error { + parser := SliceRelationshipParser(args) + if isArgsViaFile(input) && len(args) == 0 { + parser = FileRelationshipParser(input) + } + + spicedbClient, err := client.NewClient(cmd) + if err != nil { + return err + } + + batchSize := cobrautil.MustGetInt(cmd, "batch-size") + updateBatch := make([]*v1.RelationshipUpdate, 0) + doJSON := cobrautil.MustGetBool(cmd, "json") + + for { + rel, err := parser() + if errors.Is(err, ErrExhaustedRelationships) { + return writeUpdates(cmd.Context(), spicedbClient, updateBatch, doJSON) + } else if err != nil { + return err + } + + if operation != v1.RelationshipUpdate_OPERATION_DELETE { + if err := handleCaveatFlag(cmd, rel); err != nil { + return err + } + + if err := handleExpirationFlag(cmd, rel); err != nil { + return err + } + } + + updateBatch = append(updateBatch, &v1.RelationshipUpdate{ + Operation: operation, + Relationship: rel, + }) + if len(updateBatch) == batchSize { + if err := writeUpdates(cmd.Context(), spicedbClient, updateBatch, doJSON); err != nil { + return err + } + updateBatch = nil + } + } + } +} + +func handleCaveatFlag(cmd *cobra.Command, rel *v1.Relationship) error { + caveatString := cobrautil.MustGetString(cmd, "caveat") + if caveatString != "" { + if rel.OptionalCaveat != nil { + return errors.New("cannot specify a caveat in both the relationship and the --caveat flag") + } + + parts := strings.SplitN(caveatString, ":", 2) + if len(parts) == 0 { + return fmt.Errorf("invalid --caveat argument. Must be in format `caveat_name:context`, but found `%s`", caveatString) + } + + rel.OptionalCaveat = &v1.ContextualizedCaveat{ + CaveatName: parts[0], + } + + if len(parts) == 2 { + caveatCtx, err := ParseCaveatContext(parts[1]) + if err != nil { + return err + } + rel.OptionalCaveat.Context = caveatCtx + } + } + return nil +} + +func handleExpirationFlag(cmd *cobra.Command, rel *v1.Relationship) error { + expirationTime := cobrautil.MustGetString(cmd, "expiration-time") + + if expirationTime != "" { + t, err := time.Parse(time.RFC3339, expirationTime) + if err != nil { + return fmt.Errorf("could not parse RFC 3339 timestamp: %w", err) + } + rel.OptionalExpiresAt = timestamppb.New(t) + } + + return nil +} diff --git a/vendor/github.com/authzed/zed/internal/commands/relationship_nowasm.go b/vendor/github.com/authzed/zed/internal/commands/relationship_nowasm.go new file mode 100644 index 0000000..ea94114 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/commands/relationship_nowasm.go @@ -0,0 +1,12 @@ +//go:build !wasm +// +build !wasm + +package commands + +import ( + "os" + + "golang.org/x/term" +) + +var isFileTerminal = func(f *os.File) bool { return term.IsTerminal(int(f.Fd())) } diff --git a/vendor/github.com/authzed/zed/internal/commands/relationship_wasm.go b/vendor/github.com/authzed/zed/internal/commands/relationship_wasm.go new file mode 100644 index 0000000..4388932 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/commands/relationship_wasm.go @@ -0,0 +1,5 @@ +package commands + +import "os" + +var isFileTerminal = func(f *os.File) bool { return true } diff --git a/vendor/github.com/authzed/zed/internal/commands/schema.go b/vendor/github.com/authzed/zed/internal/commands/schema.go new file mode 100644 index 0000000..b56f152 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/commands/schema.go @@ -0,0 +1,87 @@ +package commands + +import ( + "context" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/stringz" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/console" +) + +func RegisterSchemaCmd(rootCmd *cobra.Command) *cobra.Command { + rootCmd.AddCommand(schemaCmd) + + schemaCmd.AddCommand(schemaReadCmd) + schemaReadCmd.Flags().Bool("json", false, "output as JSON") + + return schemaCmd +} + +var ( + schemaCmd = &cobra.Command{ + Use: "schema <subcommand>", + Short: "Manage schema for a permissions system", + } + + schemaReadCmd = &cobra.Command{ + Use: "read", + Short: "Read the schema of a permissions system", + Args: ValidationWrapper(cobra.ExactArgs(0)), + ValidArgsFunction: cobra.NoFileCompletions, + RunE: schemaReadCmdFunc, + } +) + +func schemaReadCmdFunc(cmd *cobra.Command, _ []string) error { + client, err := client.NewClient(cmd) + if err != nil { + return err + } + request := &v1.ReadSchemaRequest{} + log.Trace().Interface("request", request).Msg("requesting schema read") + + resp, err := client.ReadSchema(cmd.Context(), request) + if err != nil { + return err + } + + if cobrautil.MustGetBool(cmd, "json") { + prettyProto, err := PrettyProto(resp) + if err != nil { + return err + } + + console.Println(string(prettyProto)) + return nil + } + + console.Println(stringz.Join("\n\n", resp.SchemaText)) + return nil +} + +// ReadSchema calls read schema for the client and returns the schema found. +func ReadSchema(ctx context.Context, client client.Client) (string, error) { + request := &v1.ReadSchemaRequest{} + log.Trace().Interface("request", request).Msg("requesting schema read") + + resp, err := client.ReadSchema(ctx, request) + if err != nil { + errStatus, ok := status.FromError(err) + if !ok || errStatus.Code() != codes.NotFound { + return "", err + } + + log.Debug().Msg("no schema defined") + return "", nil + } + + return resp.SchemaText, nil +} diff --git a/vendor/github.com/authzed/zed/internal/commands/util.go b/vendor/github.com/authzed/zed/internal/commands/util.go new file mode 100644 index 0000000..6d5b4da --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/commands/util.go @@ -0,0 +1,123 @@ +package commands + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/TylerBrock/colorjson" + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/stringz" + "github.com/spf13/cobra" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/authzed/authzed-go/pkg/requestmeta" +) + +// ParseSubject parses the given subject string into its namespace, object ID +// and relation, if valid. +func ParseSubject(s string) (namespace, id, relation string, err error) { + err = stringz.SplitExact(s, ":", &namespace, &id) + if err != nil { + return + } + err = stringz.SplitExact(id, "#", &id, &relation) + if err != nil { + relation = "" + err = nil + } + return +} + +// ParseType parses a type reference of the form `namespace#relaion`. +func ParseType(s string) (namespace, relation string) { + namespace, relation, _ = strings.Cut(s, "#") + return +} + +// GetCaveatContext returns the entered caveat caveat, if any. +func GetCaveatContext(cmd *cobra.Command) (*structpb.Struct, error) { + contextString := cobrautil.MustGetString(cmd, "caveat-context") + if len(contextString) == 0 { + return nil, nil + } + + return ParseCaveatContext(contextString) +} + +// ParseCaveatContext parses the given context JSON string into caveat context, +// if valid. +func ParseCaveatContext(contextString string) (*structpb.Struct, error) { + contextMap := map[string]any{} + err := json.Unmarshal([]byte(contextString), &contextMap) + if err != nil { + return nil, fmt.Errorf("invalid caveat context JSON: %w", err) + } + + context, err := structpb.NewStruct(contextMap) + if err != nil { + return nil, fmt.Errorf("could not construct caveat context: %w", err) + } + return context, err +} + +// PrettyProto returns the given protocol buffer formatted into pretty text. +func PrettyProto(m proto.Message) ([]byte, error) { + encoded, err := protojson.Marshal(m) + if err != nil { + return nil, err + } + var obj interface{} + err = json.Unmarshal(encoded, &obj) + if err != nil { + panic("protojson decode failed: " + err.Error()) + } + + f := colorjson.NewFormatter() + f.Indent = 2 + pretty, err := f.Marshal(obj) + if err != nil { + panic("colorjson encode failed: " + err.Error()) + } + + return pretty, nil +} + +// InjectRequestID adds the value of the --request-id flag to the +// context of the given command. +func InjectRequestID(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + requestID := cobrautil.MustGetString(cmd, "request-id") + if ctx != nil && requestID != "" { + cmd.SetContext(requestmeta.WithRequestID(ctx, requestID)) + } + + return nil +} + +// ValidationError is used to wrap errors that are cobra validation errors. It should be used to +// wrap the Command.PositionalArgs function in order to be able to determine if the error is a validation error. +// This is used to determine if an error should print the usage string. Unfortunately Cobra parameter parsing +// and parameter validation are handled differently, and the latter does not trigger calling Command.FlagErrorFunc +type ValidationError struct { + error +} + +func (ve ValidationError) Is(err error) bool { + var validationError ValidationError + return errors.As(err, &validationError) +} + +// ValidationWrapper is used to be able to determine if an error is a validation error. +func ValidationWrapper(f cobra.PositionalArgs) cobra.PositionalArgs { + return func(cmd *cobra.Command, args []string) error { + if err := f(cmd, args); err != nil { + return ValidationError{error: err} + } + + return nil + } +} diff --git a/vendor/github.com/authzed/zed/internal/commands/watch.go b/vendor/github.com/authzed/zed/internal/commands/watch.go new file mode 100644 index 0000000..96b3451 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/commands/watch.go @@ -0,0 +1,212 @@ +package commands + +import ( + "context" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/spf13/cobra" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/console" +) + +var ( + watchObjectTypes []string + watchRevision string + watchTimestamps bool + watchRelationshipFilters []string +) + +func RegisterWatchCmd(rootCmd *cobra.Command) *cobra.Command { + rootCmd.AddCommand(watchCmd) + + watchCmd.Flags().StringSliceVar(&watchObjectTypes, "object_types", nil, "optional object types to watch updates for") + watchCmd.Flags().StringVar(&watchRevision, "revision", "", "optional revision at which to start watching") + watchCmd.Flags().BoolVar(&watchTimestamps, "timestamp", false, "shows timestamp of incoming update events") + return watchCmd +} + +func RegisterWatchRelationshipCmd(parentCmd *cobra.Command) *cobra.Command { + parentCmd.AddCommand(watchRelationshipsCmd) + watchRelationshipsCmd.Flags().StringSliceVar(&watchObjectTypes, "object_types", nil, "optional object types to watch updates for") + watchRelationshipsCmd.Flags().StringVar(&watchRevision, "revision", "", "optional revision at which to start watching") + watchRelationshipsCmd.Flags().BoolVar(&watchTimestamps, "timestamp", false, "shows timestamp of incoming update events") + watchRelationshipsCmd.Flags().StringSliceVar(&watchRelationshipFilters, "filter", nil, "optional filter(s) for the watch stream. Example: `optional_resource_type:optional_resource_id_or_prefix#optional_relation@optional_subject_filter`") + return watchRelationshipsCmd +} + +var watchCmd = &cobra.Command{ + Use: "watch [object_types, ...] [start_cursor]", + Short: "Watches the stream of relationship updates from the server", + Args: ValidationWrapper(cobra.RangeArgs(0, 2)), + RunE: watchCmdFunc, + Deprecated: "deprecated; please use `zed watch relationships` instead", +} + +var watchRelationshipsCmd = &cobra.Command{ + Use: "watch [object_types, ...] [start_cursor]", + Short: "Watches the stream of relationship updates from the server", + Args: ValidationWrapper(cobra.RangeArgs(0, 2)), + RunE: watchCmdFunc, +} + +func watchCmdFunc(cmd *cobra.Command, _ []string) error { + console.Printf("starting watch stream over types %v and revision %v\n", watchObjectTypes, watchRevision) + + cli, err := client.NewClient(cmd) + if err != nil { + return err + } + + relFilters := make([]*v1.RelationshipFilter, 0, len(watchRelationshipFilters)) + for _, filter := range watchRelationshipFilters { + relFilter, err := parseRelationshipFilter(filter) + if err != nil { + return err + } + relFilters = append(relFilters, relFilter) + } + + req := &v1.WatchRequest{ + OptionalObjectTypes: watchObjectTypes, + OptionalRelationshipFilters: relFilters, + } + if watchRevision != "" { + req.OptionalStartCursor = &v1.ZedToken{Token: watchRevision} + } + + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + signalctx, interruptCancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) + defer interruptCancel() + + watchStream, err := cli.Watch(ctx, req) + if err != nil { + return err + } + + for { + select { + case <-signalctx.Done(): + console.Errorf("stream interrupted after program termination\n") + return nil + case <-ctx.Done(): + console.Errorf("stream canceled after context cancellation\n") + return nil + default: + resp, err := watchStream.Recv() + if err != nil { + return err + } + + for _, update := range resp.Updates { + if watchTimestamps { + console.Printf("%v: ", time.Now()) + } + + switch update.Operation { + case v1.RelationshipUpdate_OPERATION_CREATE: + console.Printf("CREATED ") + + case v1.RelationshipUpdate_OPERATION_DELETE: + console.Printf("DELETED ") + + case v1.RelationshipUpdate_OPERATION_TOUCH: + console.Printf("TOUCHED ") + } + + subjectRelation := "" + if update.Relationship.Subject.OptionalRelation != "" { + subjectRelation = " " + update.Relationship.Subject.OptionalRelation + } + + console.Printf("%s:%s %s %s:%s%s\n", + update.Relationship.Resource.ObjectType, + update.Relationship.Resource.ObjectId, + update.Relationship.Relation, + update.Relationship.Subject.Object.ObjectType, + update.Relationship.Subject.Object.ObjectId, + subjectRelation, + ) + } + } + } +} + +func parseRelationshipFilter(relFilterStr string) (*v1.RelationshipFilter, error) { + relFilter := &v1.RelationshipFilter{} + pieces := strings.Split(relFilterStr, "@") + if len(pieces) > 2 { + return nil, fmt.Errorf("invalid relationship filter: %s", relFilterStr) + } + + if len(pieces) == 2 { + subjectFilter, err := parseSubjectFilter(pieces[1]) + if err != nil { + return nil, err + } + relFilter.OptionalSubjectFilter = subjectFilter + } + + if len(pieces) > 0 { + resourcePieces := strings.Split(pieces[0], "#") + if len(resourcePieces) > 2 { + return nil, fmt.Errorf("invalid relationship filter: %s", relFilterStr) + } + + if len(resourcePieces) == 2 { + relFilter.OptionalRelation = resourcePieces[1] + } + + resourceTypePieces := strings.Split(resourcePieces[0], ":") + if len(resourceTypePieces) > 2 { + return nil, fmt.Errorf("invalid relationship filter: %s", relFilterStr) + } + + relFilter.ResourceType = resourceTypePieces[0] + if len(resourceTypePieces) == 2 { + optionalResourceIDOrPrefix := resourceTypePieces[1] + if strings.HasSuffix(optionalResourceIDOrPrefix, "%") { + relFilter.OptionalResourceIdPrefix = strings.TrimSuffix(optionalResourceIDOrPrefix, "%") + } else { + relFilter.OptionalResourceId = optionalResourceIDOrPrefix + } + } + } + + return relFilter, nil +} + +func parseSubjectFilter(subjectFilterStr string) (*v1.SubjectFilter, error) { + subjectFilter := &v1.SubjectFilter{} + pieces := strings.Split(subjectFilterStr, "#") + if len(pieces) > 2 { + return nil, fmt.Errorf("invalid subject filter: %s", subjectFilterStr) + } + + subjectTypePieces := strings.Split(pieces[0], ":") + if len(subjectTypePieces) > 2 { + return nil, fmt.Errorf("invalid subject filter: %s", subjectFilterStr) + } + + subjectFilter.SubjectType = subjectTypePieces[0] + if len(subjectTypePieces) == 2 { + subjectFilter.OptionalSubjectId = subjectTypePieces[1] + } + + if len(pieces) == 2 { + subjectFilter.OptionalRelation = &v1.SubjectFilter_RelationFilter{ + Relation: pieces[1], + } + } + + return subjectFilter, nil +} diff --git a/vendor/github.com/authzed/zed/internal/console/console.go b/vendor/github.com/authzed/zed/internal/console/console.go new file mode 100644 index 0000000..53f815c --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/console/console.go @@ -0,0 +1,60 @@ +package console + +import ( + "fmt" + "os" + "time" + + "github.com/mattn/go-isatty" + "github.com/schollz/progressbar/v3" +) + +// Printf defines an (overridable) function for printing to the console via stdout. +var Printf = func(format string, a ...any) { + fmt.Printf(format, a...) +} + +var Print = func(a ...any) { + fmt.Print(a...) +} + +// Errorf defines an (overridable) function for printing to the console via stderr. +var Errorf = func(format string, a ...any) { + _, err := fmt.Fprintf(os.Stderr, format, a...) + if err != nil { + panic(err) + } +} + +// Println prints a line with optional values to the console. +var Println = func(values ...any) { + for _, value := range values { + Printf("%v\n", value) + } +} + +// CreateProgressBar creates a new progress bar with the given description and defaults adjusted to zed's UX experience +func CreateProgressBar(description string) *progressbar.ProgressBar { + bar := progressbar.NewOptions(-1, + progressbar.OptionSetWidth(10), + progressbar.OptionSetRenderBlankState(true), + progressbar.OptionSetVisibility(false), + ) + if isatty.IsTerminal(os.Stderr.Fd()) { + bar = progressbar.NewOptions64(-1, + progressbar.OptionSetDescription(description), + progressbar.OptionSetWriter(os.Stderr), + progressbar.OptionSetWidth(10), + progressbar.OptionThrottle(65*time.Millisecond), + progressbar.OptionShowCount(), + progressbar.OptionShowIts(), + progressbar.OptionSetItsString("relationship"), + progressbar.OptionOnCompletion(func() { _, _ = fmt.Fprint(os.Stderr, "\n") }), + progressbar.OptionSpinnerType(14), + progressbar.OptionFullWidth(), + progressbar.OptionSetRenderBlankState(true), + ) + } + + return bar +} diff --git a/vendor/github.com/authzed/zed/internal/decode/decoder.go b/vendor/github.com/authzed/zed/internal/decode/decoder.go new file mode 100644 index 0000000..f8a02ad --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/decode/decoder.go @@ -0,0 +1,215 @@ +package decode + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "regexp" + "strings" + + "github.com/rs/zerolog/log" + "gopkg.in/yaml.v3" + + composable "github.com/authzed/spicedb/pkg/composableschemadsl/compiler" + "github.com/authzed/spicedb/pkg/composableschemadsl/generator" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/validationfile" + "github.com/authzed/spicedb/pkg/validationfile/blocks" +) + +var playgroundPattern = regexp.MustCompile("^.*/s/.*/schema|relationships|assertions|expected.*$") + +// SchemaRelationships holds the schema (as a string) and a list of +// relationships (as a string) in the format from the devtools download API. +type SchemaRelationships struct { + Schema string `yaml:"schema"` + Relationships string `yaml:"relationships"` +} + +// Func will decode into the supplied object. +type Func func(out interface{}) ([]byte, bool, error) + +// DecoderForURL returns the appropriate decoder for a given URL. +// Some URLs have special handling to dereference to the actual file. +func DecoderForURL(u *url.URL) (d Func, err error) { + switch s := u.Scheme; s { + case "", "file": + d = fileDecoder(u) + case "http", "https": + d = httpDecoder(u) + default: + err = fmt.Errorf("%s scheme not supported", s) + } + return +} + +func fileDecoder(u *url.URL) Func { + return func(out interface{}) ([]byte, bool, error) { + file, err := os.Open(u.Path) + if err != nil { + return nil, false, err + } + data, err := io.ReadAll(file) + if err != nil { + return nil, false, err + } + isOnlySchema, err := unmarshalAsYAMLOrSchemaWithFile(data, out, u.Path) + return data, isOnlySchema, err + } +} + +func httpDecoder(u *url.URL) Func { + rewriteURL(u) + return directHTTPDecoder(u) +} + +func rewriteURL(u *url.URL) { + // match playground urls + if playgroundPattern.MatchString(u.Path) { + u.Path = u.Path[:strings.LastIndex(u.Path, "/")] + u.Path += "/download" + return + } + + switch u.Hostname() { + case "gist.github.com": + u.Host = "gist.githubusercontent.com" + u.Path = path.Join(u.Path, "/raw") + case "pastebin.com": + if ok, _ := path.Match("/raw/*", u.Path); ok { + return + } + u.Path = path.Join("/raw/", u.Path) + } +} + +func directHTTPDecoder(u *url.URL) Func { + return func(out interface{}) ([]byte, bool, error) { + log.Debug().Stringer("url", u).Send() + r, err := http.Get(u.String()) + if err != nil { + return nil, false, err + } + defer r.Body.Close() + data, err := io.ReadAll(r.Body) + if err != nil { + return nil, false, err + } + + isOnlySchema, err := unmarshalAsYAMLOrSchema("", data, out) + return data, isOnlySchema, err + } +} + +// Uses the files passed in the args and looks for the specified schemaFile to parse the YAML. +func unmarshalAsYAMLOrSchemaWithFile(data []byte, out interface{}, filename string) (bool, error) { + if strings.Contains(string(data), "schemaFile:") && !strings.Contains(string(data), "schema:") { + if err := yaml.Unmarshal(data, out); err != nil { + return false, err + } + validationFile, ok := out.(*validationfile.ValidationFile) + if !ok { + return false, fmt.Errorf("could not cast unmarshalled file to validationfile") + } + + // Need to join the original filepath with the requested filepath + // to construct the path to the referenced schema file. + // NOTE: This does not allow for yaml files to transitively reference + // each other's schemaFile fields. + // TODO: enable this behavior + schemaPath := filepath.Join(path.Dir(filename), validationFile.SchemaFile) + + if !filepath.IsLocal(schemaPath) { + // We want to prevent access of files that are outside of the folder + // where the command was originally invoked. This should do that. + return false, fmt.Errorf("schema filepath %s must be local to where the command was invoked", schemaPath) + } + + file, err := os.Open(schemaPath) + if err != nil { + return false, err + } + data, err = io.ReadAll(file) + if err != nil { + return false, err + } + } + return unmarshalAsYAMLOrSchema(filename, data, out) +} + +func unmarshalAsYAMLOrSchema(filename string, data []byte, out interface{}) (bool, error) { + inputString := string(data) + + // Check for indications of a schema-only file. + if !strings.Contains(inputString, "schema:") && !strings.Contains(inputString, "relationships:") { + if err := compileSchemaFromData(filename, inputString, out); err != nil { + return false, err + } + return true, nil + } + + if !strings.Contains(inputString, "schema:") && !strings.Contains(inputString, "schemaFile:") { + // If there is no schema and no schemaFile and it doesn't compile then it must be yaml with missing fields + if err := compileSchemaFromData(filename, inputString, out); err != nil { + return false, errors.New("either schema or schemaFile must be present") + } + return true, nil + } + // Try to unparse as YAML for the validation file format. + if err := yaml.Unmarshal(data, out); err != nil { + return false, err + } + + return false, nil +} + +// compileSchemaFromData attempts to compile using the old DSL and the new composable DSL, +// but prefers the new DSL. +// It returns the errors returned by both compilations. +func compileSchemaFromData(filename, schemaString string, out interface{}) error { + var ( + standardCompileErr error + composableCompiled *composable.CompiledSchema + composableCompileErr error + vfile validationfile.ValidationFile + ) + + vfile = *out.(*validationfile.ValidationFile) + vfile.Schema = blocks.ParsedSchema{ + SourcePosition: spiceerrors.SourcePosition{LineNumber: 1, ColumnPosition: 1}, + } + + _, standardCompileErr = compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: schemaString, + }, compiler.AllowUnprefixedObjectType()) + + if standardCompileErr == nil { + vfile.Schema.Schema = schemaString + } + + inputSourceFolder := filepath.Dir(filename) + composableCompiled, composableCompileErr = composable.Compile(composable.InputSchema{ + SchemaString: schemaString, + }, composable.AllowUnprefixedObjectType(), composable.SourceFolder(inputSourceFolder)) + + if composableCompileErr == nil { + compiledSchemaString, _, err := generator.GenerateSchema(composableCompiled.OrderedDefinitions) + if err != nil { + return fmt.Errorf("could not generate string schema: %w", err) + } + vfile.Schema.Schema = compiledSchemaString + } + + err := errors.Join(standardCompileErr, composableCompileErr) + + *out.(*validationfile.ValidationFile) = vfile + return err +} diff --git a/vendor/github.com/authzed/zed/internal/grpcutil/batch.go b/vendor/github.com/authzed/zed/internal/grpcutil/batch.go new file mode 100644 index 0000000..640085c --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/grpcutil/batch.go @@ -0,0 +1,68 @@ +package grpcutil + +import ( + "context" + "errors" + "fmt" + "runtime" + + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" +) + +func minimum(a int, b int) int { + if a <= b { + return a + } + return b +} + +// EachFunc is a callback function that is called for each batch. no is the +// batch number, start is the starting index of this batch in the slice, and +// end is the ending index of this batch in the slice. +type EachFunc func(ctx context.Context, no int, start int, end int) error + +// ConcurrentBatch will calculate the minimum number of batches to required to batch n items +// with batchSize batches. For each batch, it will execute the each function. +// These functions will be processed in parallel using maxWorkers number of +// goroutines. If maxWorkers is 1, then batching will happen sychronously. If +// maxWorkers is 0, then GOMAXPROCS number of workers will be used. +// +// If an error occurs during a batch, all the worker's contexts are cancelled +// and the original error is returned. +func ConcurrentBatch(ctx context.Context, n int, batchSize int, maxWorkers int, each EachFunc) error { + if n < 0 { + return errors.New("cannot batch items of length < 0") + } else if n == 0 { + // Batching zero items is a noop. + return nil + } + + if batchSize < 1 { + return errors.New("cannot batch items with batch size < 1") + } + + if maxWorkers < 0 { + return errors.New("cannot batch items with workers < 0") + } else if maxWorkers == 0 { + maxWorkers = runtime.GOMAXPROCS(0) + } + + sem := semaphore.NewWeighted(int64(maxWorkers)) + g, ctx := errgroup.WithContext(ctx) + numBatches := (n + batchSize - 1) / batchSize + for i := 0; i < numBatches; i++ { + if err := sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire semaphore for batch number %d: %w", i, err) + } + + batchNum := i + g.Go(func() error { + defer sem.Release(1) + start := batchNum * batchSize + end := minimum(start+batchSize, n) + return each(ctx, batchNum, start, end) + }) + } + return g.Wait() +} diff --git a/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go b/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go new file mode 100644 index 0000000..c6537b9 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go @@ -0,0 +1,164 @@ +package grpcutil + +import ( + "context" + "errors" + "io" + "sync" + "time" + + "github.com/rs/zerolog/log" + "golang.org/x/mod/semver" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/authzed/authzed-go/pkg/requestmeta" + "github.com/authzed/authzed-go/pkg/responsemeta" + "github.com/authzed/spicedb/pkg/releases" +) + +// Compile-time assertion that LogDispatchTrailers and CheckServerVersion implement the +// grpc.UnaryClientInterceptor interface. +var ( + _ grpc.UnaryClientInterceptor = grpc.UnaryClientInterceptor(LogDispatchTrailers) + _ grpc.UnaryClientInterceptor = grpc.UnaryClientInterceptor(CheckServerVersion) +) + +var once sync.Once + +// CheckServerVersion implements a gRPC unary interceptor that requests the server version +// from SpiceDB and, if found, compares it to the current released version. +func CheckServerVersion( + ctx context.Context, + method string, + req, reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + callOpts ...grpc.CallOption, +) error { + var headerMD metadata.MD + ctx = requestmeta.AddRequestHeaders(ctx, requestmeta.RequestServerVersion) + err := invoker(ctx, method, req, reply, cc, append(callOpts, grpc.Header(&headerMD))...) + if err != nil { + return err + } + + once.Do(func() { + version := headerMD.Get(string(responsemeta.ServerVersion)) + if len(version) == 0 { + log.Debug().Msg("error reading server version response header; it may be disabled on the server") + } else if len(version) == 1 { + currentVersion := version[0] + + // If there is a build on the version, then do not compare. + if semver.Build(currentVersion) != "" { + log.Debug().Str("this-version", currentVersion).Msg("received build version of SpiceDB") + return + } + + rctx, cancel := context.WithTimeout(ctx, time.Second*2) + defer cancel() + + state, _, release, cerr := releases.CheckIsLatestVersion(rctx, func() (string, error) { + return currentVersion, nil + }, releases.GetLatestRelease) + if cerr != nil { + log.Debug().Err(cerr).Msg("error looking up currently released version") + } else { + switch state { + case releases.UnreleasedVersion: + log.Warn().Str("version", currentVersion).Msg("not calling a released version of SpiceDB") + return + + case releases.UpdateAvailable: + log.Warn().Str("this-version", currentVersion).Str("latest-released-version", release.Version).Msgf("the version of SpiceDB being called is out of date. See: %s", release.ViewURL) + return + + case releases.UpToDate: + log.Debug().Str("latest-released-version", release.Version).Msg("the version of SpiceDB being called is the latest released version") + return + + case releases.Unknown: + log.Warn().Str("unknown-released-version", release.Version).Msg("unable to check for a new SpiceDB version") + return + + default: + panic("Unknown state for CheckAndLogRunE") + } + } + } + }) + + return nil +} + +// LogDispatchTrailers implements a gRPC unary interceptor that logs the +// dispatch metadata that is present in response trailers from SpiceDB. +func LogDispatchTrailers( + ctx context.Context, + method string, + req, reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + callOpts ...grpc.CallOption, +) error { + var trailerMD metadata.MD + err := invoker(ctx, method, req, reply, cc, append(callOpts, grpc.Trailer(&trailerMD))...) + outputDispatchTrailers(trailerMD) + return err +} + +func outputDispatchTrailers(trailerMD metadata.MD) { + log.Trace().Interface("trailers", trailerMD).Msg("parsed trailers") + + dispatchCount, trailerErr := responsemeta.GetIntResponseTrailerMetadata( + trailerMD, + responsemeta.DispatchedOperationsCount, + ) + if trailerErr != nil { + log.Debug().Err(trailerErr).Msg("error reading dispatched operations trailer") + } + + cachedCount, trailerErr := responsemeta.GetIntResponseTrailerMetadata( + trailerMD, + responsemeta.CachedOperationsCount, + ) + if trailerErr != nil { + log.Debug().Err(trailerErr).Msg("error reading cached operations trailer") + } + + log.Debug(). + Int("dispatch", dispatchCount). + Int("cached", cachedCount). + Msg("extracted response dispatch metadata") +} + +// StreamLogDispatchTrailers implements a gRPC stream interceptor that logs the +// dispatch metadata that is present in response trailers from SpiceDB. +func StreamLogDispatchTrailers( + ctx context.Context, + desc *grpc.StreamDesc, + cc *grpc.ClientConn, + method string, + streamer grpc.Streamer, + callOpts ...grpc.CallOption, +) (grpc.ClientStream, error) { + stream, err := streamer(ctx, desc, cc, method, callOpts...) + if err != nil { + return nil, err + } + + return &wrappedStream{stream}, nil +} + +type wrappedStream struct { + grpc.ClientStream +} + +func (w *wrappedStream) RecvMsg(m interface{}) error { + err := w.ClientStream.RecvMsg(m) + if err != nil && errors.Is(err, io.EOF) { + outputDispatchTrailers(w.Trailer()) + } + return err +} diff --git a/vendor/github.com/authzed/zed/internal/printers/debug.go b/vendor/github.com/authzed/zed/internal/printers/debug.go new file mode 100644 index 0000000..2dbb9a6 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/printers/debug.go @@ -0,0 +1,197 @@ +package printers + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/gookit/color" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +// DisplayCheckTrace prints out the check trace found in the given debug message. +func DisplayCheckTrace(checkTrace *v1.CheckDebugTrace, tp *TreePrinter, hasError bool) { + displayCheckTrace(checkTrace, tp, hasError, map[string]struct{}{}) +} + +func displayCheckTrace(checkTrace *v1.CheckDebugTrace, tp *TreePrinter, hasError bool, encountered map[string]struct{}) { + red := color.FgRed.Render + green := color.FgGreen.Render + cyan := color.FgCyan.Render + white := color.FgWhite.Render + faint := color.FgGray.Render + magenta := color.FgMagenta.Render + yellow := color.FgYellow.Render + + orange := color.C256(166).Sprint + purple := color.C256(99).Sprint + lightgreen := color.C256(35).Sprint + caveatColor := color.C256(198).Sprint + + hasPermission := green("✓") + resourceColor := white + permissionColor := color.FgWhite.Render + + switch checkTrace.PermissionType { + case v1.CheckDebugTrace_PERMISSION_TYPE_PERMISSION: + permissionColor = lightgreen + case v1.CheckDebugTrace_PERMISSION_TYPE_RELATION: + permissionColor = orange + } + + switch checkTrace.Result { + case v1.CheckDebugTrace_PERMISSIONSHIP_CONDITIONAL_PERMISSION: + switch checkTrace.CaveatEvaluationInfo.Result { + case v1.CaveatEvalInfo_RESULT_FALSE: + hasPermission = red("⨉") + resourceColor = faint + permissionColor = faint + + case v1.CaveatEvalInfo_RESULT_MISSING_SOME_CONTEXT: + hasPermission = magenta("?") + resourceColor = faint + permissionColor = faint + } + case v1.CheckDebugTrace_PERMISSIONSHIP_NO_PERMISSION: + hasPermission = red("⨉") + resourceColor = faint + permissionColor = faint + case v1.CheckDebugTrace_PERMISSIONSHIP_UNSPECIFIED: + hasPermission = yellow("∵") + } + + additional := "" + if checkTrace.GetWasCachedResult() { + sourceKind := "" + source := checkTrace.Source + if source != "" { + parts := strings.Split(source, ":") + if len(parts) > 0 { + sourceKind = parts[0] + } + } + switch sourceKind { + case "": + additional = cyan(" (cached)") + + case "spicedb": + additional = cyan(" (cached by spicedb)") + + case "materialize": + additional = purple(" (cached by materialize)") + + default: + additional = cyan(fmt.Sprintf(" (cached by %s)", sourceKind)) + } + } else if hasError && isPartOfCycle(checkTrace, map[string]struct{}{}) { + hasPermission = orange("!") + resourceColor = white + } + + isEndOfCycle := false + if hasError { + key := cycleKey(checkTrace) + _, isEndOfCycle = encountered[key] + if isEndOfCycle { + additional = color.C256(166).Sprint(" (cycle)") + } + encountered[key] = struct{}{} + } + + timing := "" + if checkTrace.Duration != nil { + timing = fmt.Sprintf(" (%s)", checkTrace.Duration.AsDuration().String()) + } + + tp = tp.Child( + fmt.Sprintf( + "%s %s:%s %s%s%s", + hasPermission, + resourceColor(checkTrace.Resource.ObjectType), + resourceColor(checkTrace.Resource.ObjectId), + permissionColor(checkTrace.Permission), + additional, + timing, + ), + ) + + if isEndOfCycle { + return + } + + if checkTrace.GetCaveatEvaluationInfo() != nil { + indicator := "" + exprColor := color.FgWhite.Render + switch checkTrace.CaveatEvaluationInfo.Result { + case v1.CaveatEvalInfo_RESULT_FALSE: + indicator = red("⨉") + exprColor = faint + + case v1.CaveatEvalInfo_RESULT_TRUE: + indicator = green("✓") + + case v1.CaveatEvalInfo_RESULT_MISSING_SOME_CONTEXT: + indicator = magenta("?") + } + + white := color.HEXStyle("fff") + white.SetOpts(color.Opts{color.OpItalic}) + + contextMap := checkTrace.CaveatEvaluationInfo.Context.AsMap() + caveatName := checkTrace.CaveatEvaluationInfo.CaveatName + + c := tp.Child(fmt.Sprintf("%s %s %s", indicator, exprColor(checkTrace.CaveatEvaluationInfo.Expression), caveatColor(caveatName))) + if len(contextMap) > 0 { + contextJSON, _ := json.MarshalIndent(contextMap, "", " ") + c.Child(string(contextJSON)) + } else { + if checkTrace.CaveatEvaluationInfo.Result != v1.CaveatEvalInfo_RESULT_MISSING_SOME_CONTEXT { + c.Child(faint("(no matching context found)")) + } + } + + if checkTrace.CaveatEvaluationInfo.Result == v1.CaveatEvalInfo_RESULT_MISSING_SOME_CONTEXT { + c.Child(fmt.Sprintf("missing context: %s", strings.Join(checkTrace.CaveatEvaluationInfo.PartialCaveatInfo.MissingRequiredContext, ", "))) + } + } + + if checkTrace.GetSubProblems() != nil { + for _, subProblem := range checkTrace.GetSubProblems().Traces { + displayCheckTrace(subProblem, tp, hasError, encountered) + } + } else if checkTrace.Result == v1.CheckDebugTrace_PERMISSIONSHIP_HAS_PERMISSION { + tp.Child(purple(fmt.Sprintf("%s:%s %s", checkTrace.Subject.Object.ObjectType, checkTrace.Subject.Object.ObjectId, checkTrace.Subject.OptionalRelation))) + } +} + +func cycleKey(checkTrace *v1.CheckDebugTrace) string { + return fmt.Sprintf("%s#%s", tuple.V1StringObjectRef(checkTrace.Resource), checkTrace.Permission) +} + +func isPartOfCycle(checkTrace *v1.CheckDebugTrace, encountered map[string]struct{}) bool { + if checkTrace.GetSubProblems() == nil { + return false + } + + encounteredCopy := make(map[string]struct{}, len(encountered)) + for k, v := range encountered { + encounteredCopy[k] = v + } + + key := cycleKey(checkTrace) + if _, ok := encounteredCopy[key]; ok { + return true + } + + encounteredCopy[key] = struct{}{} + + for _, subProblem := range checkTrace.GetSubProblems().Traces { + if isPartOfCycle(subProblem, encounteredCopy) { + return true + } + } + + return false +} diff --git a/vendor/github.com/authzed/zed/internal/printers/table.go b/vendor/github.com/authzed/zed/internal/printers/table.go new file mode 100644 index 0000000..fd6f024 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/printers/table.go @@ -0,0 +1,26 @@ +package printers + +import ( + "io" + + "github.com/olekukonko/tablewriter" +) + +// PrintTable writes an terminal-friendly table of the values to the target. +func PrintTable(target io.Writer, headers []string, rows [][]string) { + table := tablewriter.NewWriter(target) + table.SetHeader(headers) + table.SetAutoWrapText(false) + table.SetAutoFormatHeaders(true) + table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) + table.SetAlignment(tablewriter.ALIGN_LEFT) + table.SetCenterSeparator("") + table.SetColumnSeparator("") + table.SetRowSeparator("") + table.SetHeaderLine(false) + table.SetBorder(false) + table.SetTablePadding("\t") + table.SetNoWhiteSpace(true) + table.AppendBulk(rows) + table.Render() +} diff --git a/vendor/github.com/authzed/zed/internal/printers/tree.go b/vendor/github.com/authzed/zed/internal/printers/tree.go new file mode 100644 index 0000000..b787210 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/printers/tree.go @@ -0,0 +1,66 @@ +package printers + +import ( + "fmt" + + "github.com/jzelinskie/stringz" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" +) + +func prettySubject(subj *v1.SubjectReference) string { + if subj.OptionalRelation == "" { + return fmt.Sprintf( + "%s:%s", + stringz.TrimPrefixIndex(subj.Object.ObjectType, "/"), + subj.Object.ObjectId, + ) + } + return fmt.Sprintf( + "%s:%s->%s", + stringz.TrimPrefixIndex(subj.Object.ObjectType, "/"), + subj.Object.ObjectId, + subj.OptionalRelation, + ) +} + +// TreeNodeTree walks an Authzed Tree Node and creates corresponding nodes +// for a treeprinter. +func TreeNodeTree(tp *TreePrinter, treeNode *v1.PermissionRelationshipTree) { + if treeNode.ExpandedObject != nil { + tp = tp.Child(fmt.Sprintf( + "%s:%s->%s", + stringz.TrimPrefixIndex(treeNode.ExpandedObject.ObjectType, "/"), + treeNode.ExpandedObject.ObjectId, + treeNode.ExpandedRelation, + )) + } + switch typed := treeNode.TreeType.(type) { + case *v1.PermissionRelationshipTree_Intermediate: + switch typed.Intermediate.Operation { + case v1.AlgebraicSubjectSet_OPERATION_UNION: + union := tp.Child("union") + for _, child := range typed.Intermediate.Children { + TreeNodeTree(union, child) + } + case v1.AlgebraicSubjectSet_OPERATION_INTERSECTION: + intersection := tp.Child("intersection") + for _, child := range typed.Intermediate.Children { + TreeNodeTree(intersection, child) + } + case v1.AlgebraicSubjectSet_OPERATION_EXCLUSION: + exclusion := tp.Child("exclusion") + for _, child := range typed.Intermediate.Children { + TreeNodeTree(exclusion, child) + } + default: + panic("unknown expand operation") + } + case *v1.PermissionRelationshipTree_Leaf: + for _, subject := range typed.Leaf.Subjects { + tp.Child(prettySubject(subject)) + } + default: + panic("unknown TreeNode type") + } +} diff --git a/vendor/github.com/authzed/zed/internal/printers/treeprinter.go b/vendor/github.com/authzed/zed/internal/printers/treeprinter.go new file mode 100644 index 0000000..55c3830 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/printers/treeprinter.go @@ -0,0 +1,43 @@ +package printers + +import ( + "strings" + + "github.com/xlab/treeprint" + + "github.com/authzed/zed/internal/console" +) + +type TreePrinter struct { + tree treeprint.Tree +} + +func NewTreePrinter() *TreePrinter { + return &TreePrinter{} +} + +func (tp *TreePrinter) Child(val string) *TreePrinter { + if tp.tree == nil { + tp.tree = treeprint.NewWithRoot(val) + return tp + } + return &TreePrinter{tree: tp.tree.AddBranch(val)} +} + +func (tp *TreePrinter) Print() { + console.Println(tp.String()) +} + +func (tp *TreePrinter) Indented() string { + var sb strings.Builder + lines := strings.Split(tp.String(), "\n") + for _, line := range lines { + sb.WriteString(" " + line + "\n") + } + + return sb.String() +} + +func (tp *TreePrinter) String() string { + return tp.tree.String() +} diff --git a/vendor/github.com/authzed/zed/internal/storage/config.go b/vendor/github.com/authzed/zed/internal/storage/config.go new file mode 100644 index 0000000..1fca679 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/storage/config.go @@ -0,0 +1,194 @@ +package storage + +import ( + "encoding/json" + "errors" + "io/fs" + "os" + "path/filepath" + "runtime" + + "github.com/jzelinskie/stringz" +) + +const configFileName = "config.json" + +// ErrConfigNotFound is returned if there is no Config in a ConfigStore. +var ErrConfigNotFound = errors.New("config did not exist") + +// ErrTokenNotFound is returned if there is no Token in a ConfigStore. +var ErrTokenNotFound = errors.New("token does not exist") + +// Config represents the contents of a zed configuration file. +type Config struct { + Version string + CurrentToken string +} + +// ConfigStore is anything that can persistently store a Config. +type ConfigStore interface { + Get() (Config, error) + Put(Config) error + Exists() (bool, error) +} + +// TokenWithOverride returns a Token that retrieves its values from the reference Token, and has its values overridden +// any of the non-empty/non-nil values of the overrideToken. +func TokenWithOverride(overrideToken Token, referenceToken Token) (Token, error) { + insecure := referenceToken.Insecure + if overrideToken.Insecure != nil { + insecure = overrideToken.Insecure + } + + // done so that logging messages don't show nil for the resulting context + if insecure == nil { + bFalse := false + insecure = &bFalse + } + + noVerifyCA := referenceToken.NoVerifyCA + if overrideToken.NoVerifyCA != nil { + noVerifyCA = overrideToken.NoVerifyCA + } + + // done so that logging messages don't show nil for the resulting context + if noVerifyCA == nil { + bFalse := false + noVerifyCA = &bFalse + } + + caCert := referenceToken.CACert + if overrideToken.CACert != nil { + caCert = overrideToken.CACert + } + + return Token{ + Name: referenceToken.Name, + Endpoint: stringz.DefaultEmpty(overrideToken.Endpoint, referenceToken.Endpoint), + APIToken: stringz.DefaultEmpty(overrideToken.APIToken, referenceToken.APIToken), + Insecure: insecure, + NoVerifyCA: noVerifyCA, + CACert: caCert, + }, nil +} + +// CurrentToken is a convenient way to obtain the CurrentToken field from the +// current Config. +func CurrentToken(cs ConfigStore, ss SecretStore) (token Token, err error) { + cfg, err := cs.Get() + if err != nil { + return Token{}, err + } + + return GetTokenIfExists(cfg.CurrentToken, ss) +} + +// SetCurrentToken is a convenient way to set the CurrentToken field in a +// the current config. +func SetCurrentToken(name string, cs ConfigStore, ss SecretStore) error { + // Ensure the token exists + exists, err := TokenExists(name, ss) + if err != nil { + return err + } + + if !exists { + return ErrTokenNotFound + } + + cfg, err := cs.Get() + if err != nil { + if errors.Is(err, ErrConfigNotFound) { + cfg = Config{Version: "v1"} + } else { + return err + } + } + + cfg.CurrentToken = name + return cs.Put(cfg) +} + +// JSONConfigStore implements a ConfigStore that stores its Config in a JSON file at the provided ConfigPath. +type JSONConfigStore struct { + ConfigPath string +} + +// Enforce that our implementation satisfies the interface. +var _ ConfigStore = JSONConfigStore{} + +// Get parses a Config from the filesystem. +func (s JSONConfigStore) Get() (Config, error) { + cfgBytes, err := os.ReadFile(filepath.Join(s.ConfigPath, configFileName)) + if errors.Is(err, fs.ErrNotExist) { + return Config{}, ErrConfigNotFound + } else if err != nil { + return Config{}, err + } + + var cfg Config + if err := json.Unmarshal(cfgBytes, &cfg); err != nil { + return Config{}, err + } + + return cfg, nil +} + +// Put overwrites a Config on the filesystem. +func (s JSONConfigStore) Put(cfg Config) error { + if err := os.MkdirAll(s.ConfigPath, 0o774); err != nil { + return err + } + + cfgBytes, err := json.Marshal(cfg) + if err != nil { + return err + } + + return atomicWriteFile(filepath.Join(s.ConfigPath, configFileName), cfgBytes, 0o774) +} + +func (s JSONConfigStore) Exists() (bool, error) { + if _, err := os.Stat(filepath.Join(s.ConfigPath, configFileName)); errors.Is(err, fs.ErrNotExist) { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +} + +// atomicWriteFile writes data to filename+some suffix, then renames it into +// filename. +// +// Copyright (c) 2019 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style license that can be found +// at the following URL: +// https://github.com/tailscale/tailscale/blob/main/LICENSE +func atomicWriteFile(filename string, data []byte, perm os.FileMode) (err error) { + f, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename)+".tmp") + if err != nil { + return err + } + tmpName := f.Name() + defer func() { + if err != nil { + f.Close() + os.Remove(tmpName) + } + }() + if _, err := f.Write(data); err != nil { + return err + } + if runtime.GOOS != "windows" { + if err := f.Chmod(perm); err != nil { + return err + } + } + if err := f.Sync(); err != nil { + return err + } + if err := f.Close(); err != nil { + return err + } + return os.Rename(tmpName, filename) +} diff --git a/vendor/github.com/authzed/zed/internal/storage/secrets.go b/vendor/github.com/authzed/zed/internal/storage/secrets.go new file mode 100644 index 0000000..3436f6b --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/storage/secrets.go @@ -0,0 +1,265 @@ +package storage + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" + + "github.com/99designs/keyring" + "github.com/charmbracelet/x/term" + "github.com/jzelinskie/stringz" + + "github.com/authzed/zed/internal/console" +) + +type Token struct { + Name string + Endpoint string + APIToken string + Insecure *bool + NoVerifyCA *bool + CACert []byte +} + +func (t Token) AnyValue() bool { + if t.Endpoint != "" || t.APIToken != "" || t.Insecure != nil || t.NoVerifyCA != nil || len(t.CACert) > 0 { + return true + } + + return false +} + +func (t Token) Certificate() (cert []byte, ok bool) { + if len(t.CACert) > 0 { + return t.CACert, true + } + return nil, false +} + +func (t Token) IsInsecure() bool { + return t.Insecure != nil && *t.Insecure +} + +func (t Token) HasNoVerifyCA() bool { + return t.NoVerifyCA != nil && *t.NoVerifyCA +} + +func (t Token) Redacted() string { + prefix, _ := t.SplitAPIToken() + if prefix == "" { + return "<redacted>" + } + + return stringz.Join("_", prefix, "<redacted>") +} + +func (t Token) SplitAPIToken() (prefix, secret string) { + exploded := strings.Split(t.APIToken, "_") + return strings.Join(exploded[:len(exploded)-1], "_"), exploded[len(exploded)-1] +} + +type Secrets struct { + Tokens []Token +} + +type SecretStore interface { + Get() (Secrets, error) + Put(s Secrets) error +} + +// GetTokenIfExists returns an empty token if no token exists. +func GetTokenIfExists(name string, ss SecretStore) (Token, error) { + secrets, err := ss.Get() + if err != nil { + return Token{}, err + } + + for _, token := range secrets.Tokens { + if name == token.Name { + return token, nil + } + } + + return Token{}, nil +} + +func TokenExists(name string, ss SecretStore) (bool, error) { + secrets, err := ss.Get() + if err != nil { + return false, err + } + + for _, token := range secrets.Tokens { + if name == token.Name { + return true, nil + } + } + + return false, nil +} + +func PutToken(t Token, ss SecretStore) error { + secrets, err := ss.Get() + if err != nil { + return err + } + + replaced := false + for i, token := range secrets.Tokens { + if token.Name == t.Name { + secrets.Tokens[i] = t + replaced = true + } + } + + if !replaced { + secrets.Tokens = append(secrets.Tokens, t) + } + + return ss.Put(secrets) +} + +func RemoveToken(name string, ss SecretStore) error { + secrets, err := ss.Get() + if err != nil { + return err + } + + for i, token := range secrets.Tokens { + if token.Name == name { + secrets.Tokens = append(secrets.Tokens[:i], secrets.Tokens[i+1:]...) + break + } + } + + return ss.Put(secrets) +} + +type KeychainSecretStore struct { + ConfigPath string + ring keyring.Keyring +} + +var _ SecretStore = (*KeychainSecretStore)(nil) + +const ( + svcName = "zed" + keyringEntryName = svcName + " secrets" + envRecommendation = "Setting the environment variable `ZED_KEYRING_PASSWORD` to your password will skip prompts.\n" + keyringDoesNotExistPrompt = "Keyring file does not already exist.\nEnter a new non-empty passphrase for the new keyring file: " + keyringPrompt = "Enter passphrase to unlock zed keyring: " + emptyKeyringPasswordError = "your passphrase must not be empty" +) + +func fileExists(path string) (bool, error) { + _, err := os.Stat(path) + switch { + case err == nil: + return true, nil + case os.IsNotExist(err): + return false, nil + default: + return false, err + } +} + +func promptPassword(prompt string) (string, error) { + console.Printf(prompt) + b, err := term.ReadPassword(os.Stdin.Fd()) + if err != nil { + return "", err + } + console.Printf("\n") // Clear the line after a prompt + return string(b), err +} + +func (k *KeychainSecretStore) keyring() (keyring.Keyring, error) { + if k.ring != nil { + return k.ring, nil + } + + keyringPath := filepath.Join(k.ConfigPath, "keyring.jwt") + + ring, err := keyring.Open(keyring.Config{ + ServiceName: "zed", + FileDir: keyringPath, + FilePasswordFunc: func(_ string) (string, error) { + if password, ok := os.LookupEnv("ZED_KEYRING_PASSWORD"); ok { + return password, nil + } + + // Check if this is the first run where the keyring is created. + keyringExists, err := fileExists(filepath.Join(keyringPath, keyringEntryName)) + if err != nil { + return "", err + } + if !keyringExists { + // This is the first run and we're creating a password. + passwordString, err := promptPassword(envRecommendation + keyringDoesNotExistPrompt) + if err != nil { + return "", err + } + + if len(passwordString) == 0 { + // NOTE: we enforce a non-empty keyring password to prevent + // user frustration around accidentally setting an empty + // passphrase and then not knowing what it might be. + return "", errors.New(emptyKeyringPasswordError) + } + + return passwordString, nil + } + + passwordString, err := promptPassword(envRecommendation + keyringPrompt) + if err != nil { + return "", err + } + + return passwordString, nil + }, + }) + if err != nil { + return ring, err + } + + k.ring = ring + return ring, err +} + +func (k *KeychainSecretStore) Get() (Secrets, error) { + ring, err := k.keyring() + if err != nil { + return Secrets{}, err + } + + entry, err := ring.Get(keyringEntryName) + if err != nil { + if errors.Is(err, keyring.ErrKeyNotFound) { + return Secrets{}, nil // empty is okay! + } + return Secrets{}, err + } + + var s Secrets + err = json.Unmarshal(entry.Data, &s) + return s, err +} + +func (k *KeychainSecretStore) Put(s Secrets) error { + ring, err := k.keyring() + if err != nil { + return err + } + + data, err := json.Marshal(s) + if err != nil { + return err + } + + return ring.Set(keyring.Item{ + Key: keyringEntryName, + Data: data, + }) +} |
