summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/zed/internal
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
committermo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
commit20ef0d92694465ac86b550df139e8366a0a2b4fa (patch)
tree3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/zed/internal
parent44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff)
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/zed/internal')
-rw-r--r--vendor/github.com/authzed/zed/internal/client/client.go303
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/backup.go868
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/cmd.go187
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/context.go202
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/import.go181
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/preview.go120
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/restorer.go446
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/schema.go319
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/validate.go414
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/version.go64
-rw-r--r--vendor/github.com/authzed/zed/internal/commands/completion.go165
-rw-r--r--vendor/github.com/authzed/zed/internal/commands/permission.go673
-rw-r--r--vendor/github.com/authzed/zed/internal/commands/relationship.go561
-rw-r--r--vendor/github.com/authzed/zed/internal/commands/relationship_nowasm.go12
-rw-r--r--vendor/github.com/authzed/zed/internal/commands/relationship_wasm.go5
-rw-r--r--vendor/github.com/authzed/zed/internal/commands/schema.go87
-rw-r--r--vendor/github.com/authzed/zed/internal/commands/util.go123
-rw-r--r--vendor/github.com/authzed/zed/internal/commands/watch.go212
-rw-r--r--vendor/github.com/authzed/zed/internal/console/console.go60
-rw-r--r--vendor/github.com/authzed/zed/internal/decode/decoder.go215
-rw-r--r--vendor/github.com/authzed/zed/internal/grpcutil/batch.go68
-rw-r--r--vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go164
-rw-r--r--vendor/github.com/authzed/zed/internal/printers/debug.go197
-rw-r--r--vendor/github.com/authzed/zed/internal/printers/table.go26
-rw-r--r--vendor/github.com/authzed/zed/internal/printers/tree.go66
-rw-r--r--vendor/github.com/authzed/zed/internal/printers/treeprinter.go43
-rw-r--r--vendor/github.com/authzed/zed/internal/storage/config.go194
-rw-r--r--vendor/github.com/authzed/zed/internal/storage/secrets.go265
28 files changed, 6240 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/zed/internal/client/client.go b/vendor/github.com/authzed/zed/internal/client/client.go
new file mode 100644
index 0000000..17ddb11
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/client/client.go
@@ -0,0 +1,303 @@
+package client
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
+ "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry"
+ "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector"
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/mitchellh/go-homedir"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+ "golang.org/x/net/proxy"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/credentials/insecure"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/authzed-go/v1"
+ "github.com/authzed/grpcutil"
+
+ zgrpcutil "github.com/authzed/zed/internal/grpcutil"
+ "github.com/authzed/zed/internal/storage"
+)
+
+// Client defines an interface for making calls to SpiceDB API.
+type Client interface {
+ v1.SchemaServiceClient
+ v1.PermissionsServiceClient
+ v1.WatchServiceClient
+ v1.ExperimentalServiceClient
+}
+
+const (
+ defaultRetryExponentialBackoff = 100 * time.Millisecond
+ defaultMaxRetryAttemptDuration = 2 * time.Second
+ defaultRetryJitterFraction = 0.5
+ bulkImportRoute = "/authzed.api.v1.ExperimentalService/BulkImportRelationships"
+ importBulkRoute = "/authzed.api.v1.PermissionsService/ImportBulkRelationships"
+)
+
+// NewClient defines an (overridable) means of creating a new client.
+var (
+ NewClient = newClientForCurrentContext
+ NewClientForContext = newClientForContext
+)
+
+func newClientForCurrentContext(cmd *cobra.Command) (Client, error) {
+ configStore, secretStore := DefaultStorage()
+ token, err := GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
+ if err != nil {
+ return nil, err
+ }
+
+ dialOpts, err := DialOptsFromFlags(cmd, token)
+ if err != nil {
+ return nil, err
+ }
+
+ if cobrautil.MustGetString(cmd, "proxy") != "" {
+ token.Endpoint = withPassthroughTarget(token.Endpoint)
+ }
+
+ client, err := authzed.NewClientWithExperimentalAPIs(token.Endpoint, dialOpts...)
+ if err != nil {
+ return nil, err
+ }
+
+ return client, err
+}
+
+func newClientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) {
+ currentToken, err := storage.GetTokenIfExists(contextName, secretStore)
+ if err != nil {
+ return nil, err
+ }
+
+ token, err := GetTokenWithCLIOverride(cmd, currentToken)
+ if err != nil {
+ return nil, err
+ }
+
+ dialOpts, err := DialOptsFromFlags(cmd, token)
+ if err != nil {
+ return nil, err
+ }
+
+ if cobrautil.MustGetString(cmd, "proxy") != "" {
+ token.Endpoint = withPassthroughTarget(token.Endpoint)
+ }
+
+ return authzed.NewClient(token.Endpoint, dialOpts...)
+}
+
+// GetCurrentTokenWithCLIOverride returns the current token, but overridden by any parameter specified via CLI args
+func GetCurrentTokenWithCLIOverride(cmd *cobra.Command, configStore storage.ConfigStore, secretStore storage.SecretStore) (storage.Token, error) {
+ // Handle the no-config case separately
+ configExists, err := configStore.Exists()
+ if err != nil {
+ return storage.Token{}, err
+ }
+ if !configExists {
+ return GetTokenWithCLIOverride(cmd, storage.Token{})
+ }
+ token, err := storage.CurrentToken(
+ configStore,
+ secretStore,
+ )
+ if err != nil {
+ return storage.Token{}, err
+ }
+
+ return GetTokenWithCLIOverride(cmd, token)
+}
+
+// GetTokenWithCLIOverride returns the provided token, but overridden by any parameter specified explicitly via command
+// flags
+func GetTokenWithCLIOverride(cmd *cobra.Command, token storage.Token) (storage.Token, error) {
+ overrideToken, err := tokenFromCli(cmd)
+ if err != nil {
+ return storage.Token{}, err
+ }
+
+ result, err := storage.TokenWithOverride(
+ overrideToken,
+ token,
+ )
+ if err != nil {
+ return storage.Token{}, err
+ }
+
+ log.Trace().Bool("context-override-via-cli", overrideToken.AnyValue()).Interface("context", result).Send()
+ return result, nil
+}
+
+func tokenFromCli(cmd *cobra.Command) (storage.Token, error) {
+ certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path")
+ var certBytes []byte
+ var err error
+ if certPath != "" {
+ certBytes, err = os.ReadFile(certPath)
+ if err != nil {
+ return storage.Token{}, fmt.Errorf("failed to read certificate: %w", err)
+ }
+ }
+
+ explicitInsecure := cmd.Flags().Changed("insecure")
+ var notSecure *bool
+ if explicitInsecure {
+ i := cobrautil.MustGetBool(cmd, "insecure")
+ notSecure = &i
+ }
+
+ explicitNoVerifyCA := cmd.Flags().Changed("no-verify-ca")
+ var notVerifyCA *bool
+ if explicitNoVerifyCA {
+ nvc := cobrautil.MustGetBool(cmd, "no-verify-ca")
+ notVerifyCA = &nvc
+ }
+ overrideToken := storage.Token{
+ APIToken: cobrautil.MustGetString(cmd, "token"),
+ Endpoint: cobrautil.MustGetString(cmd, "endpoint"),
+ Insecure: notSecure,
+ NoVerifyCA: notVerifyCA,
+ CACert: certBytes,
+ }
+ return overrideToken, nil
+}
+
+// DefaultStorage returns the default configured config store and secret store.
+var DefaultStorage = defaultStorage
+
+func defaultStorage() (storage.ConfigStore, storage.SecretStore) {
+ var home string
+ if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
+ home = filepath.Join(xdg, "zed")
+ } else {
+ hmdir, _ := homedir.Dir()
+ home = filepath.Join(hmdir, ".zed")
+ }
+ return &storage.JSONConfigStore{ConfigPath: home},
+ &storage.KeychainSecretStore{ConfigPath: home}
+}
+
+func certOption(token storage.Token) (opt grpc.DialOption, err error) {
+ verification := grpcutil.VerifyCA
+ if token.HasNoVerifyCA() {
+ verification = grpcutil.SkipVerifyCA
+ }
+
+ if certBytes, ok := token.Certificate(); ok {
+ return grpcutil.WithCustomCertBytes(verification, certBytes)
+ }
+
+ return grpcutil.WithSystemCerts(verification)
+}
+
+func isNoneOf(routes ...string) func(_ context.Context, c interceptors.CallMeta) bool {
+ return func(_ context.Context, c interceptors.CallMeta) bool {
+ for _, route := range routes {
+ if route == c.FullMethod() {
+ return false
+ }
+ }
+ return true
+ }
+}
+
+// DialOptsFromFlags returns the dial options from the CLI-specified flags.
+func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOption, error) {
+ maxRetries := cobrautil.MustGetUint(cmd, "max-retries")
+ retryOpts := []retry.CallOption{
+ retry.WithBackoff(retry.BackoffExponentialWithJitterBounded(defaultRetryExponentialBackoff,
+ defaultRetryJitterFraction, defaultMaxRetryAttemptDuration)),
+ retry.WithCodes(codes.ResourceExhausted, codes.Unavailable, codes.Aborted, codes.Unknown, codes.Internal),
+ retry.WithMax(maxRetries),
+ retry.WithOnRetryCallback(func(_ context.Context, attempt uint, err error) {
+ log.Error().Err(err).Uint("attempt", attempt).Msg("retrying gRPC call")
+ }),
+ }
+ unaryInterceptors := []grpc.UnaryClientInterceptor{
+ zgrpcutil.LogDispatchTrailers,
+ retry.UnaryClientInterceptor(retryOpts...),
+ }
+
+ streamInterceptors := []grpc.StreamClientInterceptor{
+ zgrpcutil.StreamLogDispatchTrailers,
+ selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(bulkImportRoute, importBulkRoute))),
+ }
+
+ if !cobrautil.MustGetBool(cmd, "skip-version-check") {
+ unaryInterceptors = append(unaryInterceptors, zgrpcutil.CheckServerVersion)
+ }
+
+ opts := []grpc.DialOption{
+ grpc.WithChainUnaryInterceptor(unaryInterceptors...),
+ grpc.WithChainStreamInterceptor(streamInterceptors...),
+ }
+
+ proxyAddr := cobrautil.MustGetString(cmd, "proxy")
+
+ if proxyAddr != "" {
+ addr, err := url.Parse(proxyAddr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse socks5 proxy addr: %w", err)
+ }
+
+ dialer, err := proxy.SOCKS5("tcp", addr.Host, nil, proxy.Direct)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create socks5 proxy dialer: %w", err)
+ }
+
+ opts = append(opts, grpc.WithContextDialer(func(_ context.Context, addr string) (net.Conn, error) {
+ return dialer.Dial("tcp", addr)
+ }))
+ }
+
+ if token.IsInsecure() {
+ opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
+ opts = append(opts, grpcutil.WithInsecureBearerToken(token.APIToken))
+ } else {
+ opts = append(opts, grpcutil.WithBearerToken(token.APIToken))
+ certOpt, err := certOption(token)
+ if err != nil {
+ return nil, fmt.Errorf("failed to configure TLS cert: %w", err)
+ }
+ opts = append(opts, certOpt)
+ }
+
+ hostnameOverride := cobrautil.MustGetString(cmd, "hostname-override")
+ if hostnameOverride != "" {
+ opts = append(opts, grpc.WithAuthority(hostnameOverride))
+ }
+
+ maxMessageSize := cobrautil.MustGetInt(cmd, "max-message-size")
+ if maxMessageSize != 0 {
+ opts = append(opts, grpc.WithDefaultCallOptions(
+ // The default max client message size is 4mb.
+ // It's conceivable that a sufficiently complex
+ // schema will easily surpass this, so we set the
+ // limit higher here.
+ grpc.MaxCallRecvMsgSize(maxMessageSize),
+ grpc.MaxCallSendMsgSize(maxMessageSize),
+ ))
+ }
+
+ return opts, nil
+}
+
+func withPassthroughTarget(endpoint string) string {
+ // If it already has a scheme, return as-is
+ if strings.Contains(endpoint, "://") {
+ return endpoint
+ }
+ return "passthrough:///" + endpoint
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/backup.go b/vendor/github.com/authzed/zed/internal/cmd/backup.go
new file mode 100644
index 0000000..57fa83c
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/backup.go
@@ -0,0 +1,868 @@
+package cmd
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/mattn/go-isatty"
+ "github.com/rodaine/table"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+ "golang.org/x/exp/constraints"
+ "golang.org/x/exp/maps"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ schemapkg "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/schemadsl/generator"
+ "github.com/authzed/spicedb/pkg/tuple"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/console"
+ "github.com/authzed/zed/pkg/backupformat"
+)
+
+const (
+ returnIfExists = true
+ doNotReturnIfExists = false
+)
+
+// cobraRunEFunc is the signature of a cobra.Command.RunE function.
+type cobraRunEFunc = func(cmd *cobra.Command, args []string) (err error)
+
+// withErrorHandling is a wrapper that centralizes error handling, instead of having to scatter it around the command logic.
+func withErrorHandling(f cobraRunEFunc) cobraRunEFunc {
+ return func(cmd *cobra.Command, args []string) (err error) {
+ return addSizeErrInfo(f(cmd, args))
+ }
+}
+
+var (
+ backupCmd = &cobra.Command{
+ Use: "backup <filename>",
+ Short: "Create, restore, and inspect permissions system backups",
+ Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
+ // Create used to be on the root, so add it here for back-compat.
+ RunE: withErrorHandling(backupCreateCmdFunc),
+ }
+
+ backupCreateCmd = &cobra.Command{
+ Use: "create <filename>",
+ Short: "Backup a permission system to a file",
+ Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
+ RunE: withErrorHandling(backupCreateCmdFunc),
+ }
+
+ backupRestoreCmd = &cobra.Command{
+ Use: "restore <filename>",
+ Short: "Restore a permission system from a file",
+ Args: commands.ValidationWrapper(commands.StdinOrExactArgs(1)),
+ RunE: backupRestoreCmdFunc,
+ }
+
+ backupParseSchemaCmd = &cobra.Command{
+ Use: "parse-schema <filename>",
+ Short: "Extract the schema from a backup file",
+ Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return backupParseSchemaCmdFunc(cmd, os.Stdout, args)
+ },
+ }
+
+ backupParseRevisionCmd = &cobra.Command{
+ Use: "parse-revision <filename>",
+ Short: "Extract the revision from a backup file",
+ Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return backupParseRevisionCmdFunc(cmd, os.Stdout, args)
+ },
+ }
+
+ backupParseRelsCmd = &cobra.Command{
+ Use: "parse-relationships <filename>",
+ Short: "Extract the relationships from a backup file",
+ Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return backupParseRelsCmdFunc(cmd, os.Stdout, args)
+ },
+ }
+
+ backupRedactCmd = &cobra.Command{
+ Use: "redact <filename>",
+ Short: "Redact a backup file to remove sensitive information",
+ Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return backupRedactCmdFunc(cmd, args)
+ },
+ }
+)
+
+func registerBackupCmd(rootCmd *cobra.Command) {
+ rootCmd.AddCommand(backupCmd)
+ registerBackupCreateFlags(backupCmd)
+
+ backupCmd.AddCommand(backupCreateCmd)
+ registerBackupCreateFlags(backupCreateCmd)
+
+ backupCmd.AddCommand(backupRestoreCmd)
+ registerBackupRestoreFlags(backupRestoreCmd)
+
+ backupCmd.AddCommand(backupRedactCmd)
+ backupRedactCmd.Flags().Bool("redact-definitions", true, "redact definitions")
+ backupRedactCmd.Flags().Bool("redact-relations", true, "redact relations")
+ backupRedactCmd.Flags().Bool("redact-object-ids", true, "redact object IDs")
+ backupRedactCmd.Flags().Bool("print-redacted-object-ids", false, "prints the redacted object IDs")
+
+ // Restore used to be on the root, so add it there too, but hidden.
+ restoreCmd := &cobra.Command{
+ Use: "restore <filename>",
+ Short: "Restore a permission system from a backup file",
+ Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
+ RunE: backupRestoreCmdFunc,
+ Hidden: true,
+ }
+ rootCmd.AddCommand(restoreCmd)
+ registerBackupRestoreFlags(restoreCmd)
+
+ backupCmd.AddCommand(backupParseSchemaCmd)
+ backupParseSchemaCmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix")
+ backupParseSchemaCmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax")
+
+ backupCmd.AddCommand(backupParseRevisionCmd)
+ backupCmd.AddCommand(backupParseRelsCmd)
+ backupParseRelsCmd.Flags().String("prefix-filter", "", "Include only relationships with a given prefix")
+}
+
+func registerBackupRestoreFlags(cmd *cobra.Command) {
+ cmd.Flags().Uint("batch-size", 1_000, "restore relationship write batch size")
+ cmd.Flags().Uint("batches-per-transaction", 10, "number of batches per transaction")
+ cmd.Flags().String("conflict-strategy", "fail", "strategy used when a conflicting relationship is found. Possible values: fail, skip, touch")
+ cmd.Flags().Bool("disable-retries", false, "disable retries when an error is determined to be retryable (e.g. serialization errors)")
+ cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix")
+ cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax")
+ cmd.Flags().Duration("request-timeout", 30*time.Second, "timeout for each request performed during restore")
+}
+
+func registerBackupCreateFlags(cmd *cobra.Command) {
+ cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix")
+ cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax")
+ cmd.Flags().Uint32("page-limit", 0, "defines the number of relationships to be read by requested page during backup")
+}
+
+func createBackupFile(filename string, returnIfExists bool) (*os.File, bool, error) {
+ if filename == "-" {
+ log.Trace().Str("filename", "- (stdout)").Send()
+ return os.Stdout, false, nil
+ }
+
+ log.Trace().Str("filename", filename).Send()
+
+ if _, err := os.Stat(filename); err == nil {
+ if !returnIfExists {
+ return nil, false, fmt.Errorf("backup file already exists: %s", filename)
+ }
+
+ f, err := os.OpenFile(filename, os.O_RDWR|os.O_APPEND, 0o644)
+ if err != nil {
+ return nil, false, fmt.Errorf("unable to open existing backup file: %w", err)
+ }
+
+ return f, true, nil
+ }
+
+ f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o644)
+ if err != nil {
+ return nil, false, fmt.Errorf("unable to create backup file: %w", err)
+ }
+
+ return f, false, nil
+}
+
+var (
+ missingAllowedTypes = regexp.MustCompile(`(\s*)(relation)(.+)(/\* missing allowed types \*/)(.*)`)
+ shortRelations = regexp.MustCompile(`(\s*)relation [a-z][a-z0-9_]:(.+)`)
+)
+
+func partialPrefixMatch(name, prefix string) bool {
+ return strings.HasPrefix(name, prefix+"/")
+}
+
+func filterSchemaDefs(schema, prefix string) (filteredSchema string, err error) {
+ if schema == "" || prefix == "" {
+ return schema, nil
+ }
+
+ compiledSchema, err := compiler.Compile(
+ compiler.InputSchema{Source: "schema", SchemaString: schema},
+ compiler.AllowUnprefixedObjectType(),
+ compiler.SkipValidation(),
+ )
+ if err != nil {
+ return "", fmt.Errorf("error reading schema: %w", err)
+ }
+
+ var prefixedDefs []compiler.SchemaDefinition
+ for _, def := range compiledSchema.ObjectDefinitions {
+ if partialPrefixMatch(def.Name, prefix) {
+ prefixedDefs = append(prefixedDefs, def)
+ }
+ }
+ for _, def := range compiledSchema.CaveatDefinitions {
+ if partialPrefixMatch(def.Name, prefix) {
+ prefixedDefs = append(prefixedDefs, def)
+ }
+ }
+
+ if len(prefixedDefs) == 0 {
+ return "", errors.New("filtered all definitions from schema")
+ }
+
+ filteredSchema, _, err = generator.GenerateSchema(prefixedDefs)
+ if err != nil {
+ return "", fmt.Errorf("error generating filtered schema: %w", err)
+ }
+
+ // Validate that the type system for the generated schema is comprehensive.
+ compiledFilteredSchema, err := compiler.Compile(
+ compiler.InputSchema{Source: "generated-schema", SchemaString: filteredSchema},
+ compiler.AllowUnprefixedObjectType(),
+ )
+ if err != nil {
+ return "", fmt.Errorf("generated invalid schema: %w", err)
+ }
+
+ for _, rawDef := range compiledFilteredSchema.ObjectDefinitions {
+ ts := schemapkg.NewTypeSystem(schemapkg.ResolverForCompiledSchema(*compiledFilteredSchema))
+ def, err := schemapkg.NewDefinition(ts, rawDef)
+ if err != nil {
+ return "", fmt.Errorf("generated invalid schema: %w", err)
+ }
+ if _, err := def.Validate(context.Background()); err != nil {
+ return "", fmt.Errorf("generated invalid schema: %w", err)
+ }
+ }
+
+ return
+}
+
+func hasRelPrefix(rel *v1.Relationship, prefix string) bool {
+ // Skip any relationships without the prefix on both sides.
+ return strings.HasPrefix(rel.Resource.ObjectType, prefix) &&
+ strings.HasPrefix(rel.Subject.Object.ObjectType, prefix)
+}
+
+func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
+ prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter")
+ pageLimit := cobrautil.MustGetUint32(cmd, "page-limit")
+
+ backupFileName, err := computeBackupFileName(cmd, args)
+ if err != nil {
+ return err
+ }
+
+ backupFile, backupExists, err := createBackupFile(backupFileName, returnIfExists)
+ if err != nil {
+ return err
+ }
+
+ defer func(e *error) {
+ *e = errors.Join(*e, backupFile.Sync())
+ *e = errors.Join(*e, backupFile.Close())
+ }(&err)
+
+ // the goal of this file is to keep the bulk export cursor in case the process is terminated
+ // and we need to resume from where we left off. OCF does not support in-place record updates.
+ progressFile, cursor, err := openProgressFile(backupFileName, backupExists)
+ if err != nil {
+ return err
+ }
+
+ var backupCompleted bool
+ defer func(e *error) {
+ *e = errors.Join(*e, progressFile.Sync())
+ *e = errors.Join(*e, progressFile.Close())
+
+ if backupCompleted {
+ if err := os.Remove(progressFile.Name()); err != nil {
+ log.Warn().
+ Str("progress-file", progressFile.Name()).
+ Msg("failed to remove progress file, consider removing it manually")
+ }
+ }
+ }(&err)
+
+ c, err := client.NewClient(cmd)
+ if err != nil {
+ return fmt.Errorf("unable to initialize client: %w", err)
+ }
+
+ var zedToken *v1.ZedToken
+ var encoder *backupformat.Encoder
+ if backupExists {
+ encoder, err = backupformat.NewEncoderForExisting(backupFile)
+ if err != nil {
+ return fmt.Errorf("error creating backup file encoder: %w", err)
+ }
+ } else {
+ encoder, zedToken, err = encoderForNewBackup(cmd, c, backupFile)
+ if err != nil {
+ return err
+ }
+ }
+
+ defer func(e *error) { *e = errors.Join(*e, encoder.Close()) }(&err)
+
+ if zedToken == nil && cursor == nil {
+ return errors.New("malformed existing backup, consider recreating it")
+ }
+
+ req := &v1.BulkExportRelationshipsRequest{
+ OptionalLimit: pageLimit,
+ OptionalCursor: cursor,
+ }
+
+ // if a cursor is present, zedtoken is not needed (it is already in the cursor)
+ if zedToken != nil {
+ req.Consistency = &v1.Consistency{
+ Requirement: &v1.Consistency_AtExactSnapshot{
+ AtExactSnapshot: zedToken,
+ },
+ }
+ }
+
+ ctx := cmd.Context()
+ relationshipStream, err := c.BulkExportRelationships(ctx, req)
+ if err != nil {
+ return fmt.Errorf("error exporting relationships: %w", err)
+ }
+
+ relationshipReadStart := time.Now()
+ tick := time.Tick(5 * time.Second)
+ bar := console.CreateProgressBar("processing backup")
+ var relsFilteredOut, relsProcessed uint64
+ defer func() {
+ _ = bar.Finish()
+
+ evt := log.Info().
+ Uint64("filtered", relsFilteredOut).
+ Uint64("processed", relsProcessed).
+ Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))).
+ Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second))
+ if isCanceled(err) {
+ evt.Msg("backup canceled - resume by restarting the backup command")
+ } else if err != nil {
+ evt.Msg("backup failed")
+ } else {
+ evt.Msg("finished backup")
+ }
+ }()
+
+ for {
+ if err := ctx.Err(); err != nil {
+ if isCanceled(err) {
+ return context.Canceled
+ }
+
+ return fmt.Errorf("aborted backup: %w", err)
+ }
+
+ relsResp, err := relationshipStream.Recv()
+ if err != nil {
+ if isCanceled(err) {
+ return context.Canceled
+ }
+
+ if !errors.Is(err, io.EOF) {
+ return fmt.Errorf("error receiving relationships: %w", err)
+ }
+ break
+ }
+
+ for _, rel := range relsResp.Relationships {
+ if hasRelPrefix(rel, prefixFilter) {
+ if err := encoder.Append(rel); err != nil {
+ return fmt.Errorf("error storing relationship: %w", err)
+ }
+ } else {
+ relsFilteredOut++
+ }
+
+ relsProcessed++
+ if err := bar.Add(1); err != nil {
+ return fmt.Errorf("error incrementing progress bar: %w", err)
+ }
+
+ // progress fallback in case there is no TTY
+ if !isatty.IsTerminal(os.Stderr.Fd()) {
+ select {
+ case <-tick:
+ log.Info().
+ Uint64("filtered", relsFilteredOut).
+ Uint64("processed", relsProcessed).
+ Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))).
+ Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)).
+ Msg("backup progress")
+ default:
+ }
+ }
+ }
+
+ if err := writeProgress(progressFile, relsResp); err != nil {
+ return err
+ }
+ }
+
+ backupCompleted = true
+ return nil
+}
+
+// encoderForNewBackup creates a new encoder for a new zed backup file. It returns the ZedToken at which the backup
+// must be taken.
+func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.File) (*backupformat.Encoder, *v1.ZedToken, error) {
+ prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter")
+
+ schemaResp, err := c.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{})
+ if err != nil {
+ return nil, nil, fmt.Errorf("error reading schema: %w", err)
+ }
+ if schemaResp.ReadAt == nil {
+ return nil, nil, fmt.Errorf("`backup` is not supported on this version of SpiceDB")
+ }
+ schema := schemaResp.SchemaText
+
+ // Remove any invalid relations generated from old, backwards-incompat
+ // Serverless permission systems.
+ if cobrautil.MustGetBool(cmd, "rewrite-legacy") {
+ schema = rewriteLegacy(schema)
+ }
+
+ // Skip any definitions without the provided prefix
+
+ if prefixFilter != "" {
+ schema, err = filterSchemaDefs(schema, prefixFilter)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ zedToken := schemaResp.ReadAt
+
+ encoder, err := backupformat.NewEncoder(backupFile, schema, zedToken)
+ if err != nil {
+ return nil, nil, fmt.Errorf("error creating backup file encoder: %w", err)
+ }
+
+ return encoder, zedToken, nil
+}
+
+func writeProgress(progressFile *os.File, relsResp *v1.BulkExportRelationshipsResponse) error {
+ err := progressFile.Truncate(0)
+ if err != nil {
+ return fmt.Errorf("unable to truncate backup progress file: %w", err)
+ }
+
+ _, err = progressFile.Seek(0, 0)
+ if err != nil {
+ return fmt.Errorf("unable to seek backup progress file: %w", err)
+ }
+
+ _, err = progressFile.WriteString(relsResp.AfterResultCursor.Token)
+ if err != nil {
+ return fmt.Errorf("unable to write result cursor to backup progress file: %w", err)
+ }
+
+ return nil
+}
+
+// openProgressFile returns the progress marker file and the stored progress cursor if it exists, or creates
+// a new one if it does not exist. If the backup file exists, but the progress marker does not, it will return an error.
+//
+// The progress marker file keeps track of the last successful cursor received from the server, and is used to resume
+// backups in case of failure.
+func openProgressFile(backupFileName string, backupAlreadyExisted bool) (*os.File, *v1.Cursor, error) {
+ var cursor *v1.Cursor
+ progressFileName := toLockFileName(backupFileName)
+ var progressFile *os.File
+ // choose the progress-file open mode depending on whether a backup already exists
+ var fileMode int
+ readCursor, err := os.ReadFile(progressFileName)
+ if backupAlreadyExisted && (os.IsNotExist(err) || len(readCursor) == 0) {
+ return nil, nil, fmt.Errorf("backup file %s already exists", backupFileName)
+ } else if backupAlreadyExisted && err == nil {
+ cursor = &v1.Cursor{
+ Token: string(readCursor),
+ }
+
+ // if backup existed and there is a progress marker, the latter should not be truncated to make sure the
+ // cursor stays around in case of a failure before we even start ingesting from bulk export
+ fileMode = os.O_WRONLY | os.O_CREATE
+ log.Info().Str("filename", backupFileName).Msg("backup file already exists, will resume")
+ } else {
+ // if a backup did not exist, make sure to truncate the progress file
+ fileMode = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
+ }
+
+ progressFile, err = os.OpenFile(progressFileName, fileMode, 0o644)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return progressFile, cursor, nil
+}
+
+func toLockFileName(backupFileName string) string {
+ return backupFileName + ".lock"
+}
+
+// computeBackupFileName computes the backup file name based on the command arguments.
+// If no file name is provided, it derives one from the current zed context name.
+func computeBackupFileName(cmd *cobra.Command, args []string) (string, error) {
+ if len(args) > 0 {
+ return args[0], nil
+ }
+
+ configStore, secretStore := client.DefaultStorage()
+ token, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
+ if err != nil {
+ return "", fmt.Errorf("failed to determine current zed context: %w", err)
+ }
+
+ ex, err := os.Executable()
+ if err != nil {
+ return "", err
+ }
+ exPath := filepath.Dir(ex)
+
+ backupFileName := filepath.Join(exPath, token.Name+".zedbackup")
+
+ return backupFileName, nil
+}
+
+func isCanceled(err error) bool {
+ if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled {
+ return true
+ }
+
+ return errors.Is(err, context.Canceled)
+}
+
+func openRestoreFile(filename string) (*os.File, int64, error) {
+ if filename == "" {
+ log.Trace().Str("filename", "(stdin)").Send()
+ return os.Stdin, -1, nil
+ }
+
+ log.Trace().Str("filename", filename).Send()
+
+ stats, err := os.Stat(filename)
+ if err != nil {
+ return nil, 0, fmt.Errorf("unable to stat restore file: %w", err)
+ }
+
+ f, err := os.Open(filename)
+ if err != nil {
+ return nil, 0, fmt.Errorf("unable to open restore file: %w", err)
+ }
+
+ return f, stats.Size(), nil
+}
+
+// backupRestoreCmdFunc implements `zed backup restore`: it decodes a backup
+// (from the file named in args, or stdin when none is given), optionally
+// rewrites/filters the schema per the command's flags, and hands off to a
+// restorer that writes schema and relationships to the target SpiceDB.
+func backupRestoreCmdFunc(cmd *cobra.Command, args []string) error {
+	decoder, closer, err := decoderFromArgs(args...)
+	if err != nil {
+		return err
+	}
+
+	// Join close errors into the named return so they are not silently lost.
+	defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err)
+	defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err)
+
+	if loadedToken := decoder.ZedToken(); loadedToken != nil {
+		log.Debug().Str("revision", loadedToken.Token).Msg("parsed revision")
+	}
+
+	schema := decoder.Schema()
+
+	// Remove any invalid relations generated from old, backwards-incompat
+	// Serverless permission systems.
+	if cobrautil.MustGetBool(cmd, "rewrite-legacy") {
+		schema = rewriteLegacy(schema)
+	}
+
+	// Skip any definitions without the provided prefix
+	prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter")
+	if prefixFilter != "" {
+		schema, err = filterSchemaDefs(schema, prefixFilter)
+		if err != nil {
+			return err
+		}
+	}
+	log.Debug().Str("schema", schema).Bool("filtered", prefixFilter != "").Msg("parsed schema")
+
+	c, err := client.NewClient(cmd)
+	if err != nil {
+		return fmt.Errorf("unable to initialize client: %w", err)
+	}
+
+	batchSize := cobrautil.MustGetUint(cmd, "batch-size")
+	batchesPerTransaction := cobrautil.MustGetUint(cmd, "batches-per-transaction")
+
+	strategy, err := GetEnum[ConflictStrategy](cmd, "conflict-strategy", conflictStrategyMapping)
+	if err != nil {
+		return err
+	}
+	disableRetries := cobrautil.MustGetBool(cmd, "disable-retries")
+	requestTimeout := cobrautil.MustGetDuration(cmd, "request-timeout")
+
+	return newRestorer(schema, decoder, c, prefixFilter, batchSize, batchesPerTransaction, strategy,
+		disableRetries, requestTimeout).restoreFromDecoder(cmd.Context())
+}
+
+// GetEnum is a helper for getting an enum value from a string cobra flag.
+// The flag value is trimmed and lowercased before being matched against the
+// keys of mapping; an unrecognized value yields the zero enum value and an
+// error listing the accepted keys.
+func GetEnum[E constraints.Integer](cmd *cobra.Command, name string, mapping map[string]E) (E, error) {
+	value := cobrautil.MustGetString(cmd, name)
+	value = strings.TrimSpace(strings.ToLower(value))
+	if enum, ok := mapping[value]; ok {
+		return enum, nil
+	}
+
+	var zeroValueE E
+	return zeroValueE, fmt.Errorf("unexpected flag '%s' value '%s': should be one of %v", name, value, maps.Keys(mapping))
+}
+
+// backupParseSchemaCmdFunc prints the schema stored in a backup file to out,
+// after applying the same --rewrite-legacy and --prefix-filter transformations
+// used by the restore command.
+func backupParseSchemaCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error {
+	decoder, closer, err := decoderFromArgs(args...)
+	if err != nil {
+		return err
+	}
+
+	// Join close errors into the returned error so they are not lost.
+	defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err)
+	defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err)
+
+	schema := decoder.Schema()
+
+	// Remove any invalid relations generated from old, backwards-incompat
+	// Serverless permission systems.
+	if cobrautil.MustGetBool(cmd, "rewrite-legacy") {
+		schema = rewriteLegacy(schema)
+	}
+
+	// Skip any definitions without the provided prefix
+	if prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter"); prefixFilter != "" {
+		schema, err = filterSchemaDefs(schema, prefixFilter)
+		if err != nil {
+			return err
+		}
+	}
+
+	_, err = fmt.Fprintln(out, schema)
+	return err
+}
+
+// backupParseRevisionCmdFunc prints the ZedToken revision recorded in a
+// backup file to out, or errors when the backup carries no revision.
+func backupParseRevisionCmdFunc(_ *cobra.Command, out io.Writer, args []string) error {
+	decoder, closer, err := decoderFromArgs(args...)
+	if err != nil {
+		return err
+	}
+
+	// Join close errors into the returned error so they are not lost.
+	defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err)
+	defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err)
+
+	loadedToken := decoder.ZedToken()
+	if loadedToken == nil {
+		return fmt.Errorf("failed to parse decoded revision")
+	}
+
+	_, err = fmt.Fprintln(out, loadedToken.Token)
+	return err
+}
+
+// backupRedactCmdFunc implements `zed backup redact`: it streams the backup
+// named in args[0] through a redactor (controlled by the --redact-* flags),
+// writes the result to "<input>.redacted", and then prints tables mapping the
+// original definition/caveat/relation names (and optionally object IDs) to
+// their redacted replacements.
+func backupRedactCmdFunc(cmd *cobra.Command, args []string) error {
+	decoder, closer, err := decoderFromArgs(args...)
+	if err != nil {
+		return fmt.Errorf("error creating restore file decoder: %w", err)
+	}
+
+	// Join close errors into the returned error so they are not lost.
+	defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err)
+	defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err)
+
+	filename := args[0] + ".redacted"
+	writer, _, err := createBackupFile(filename, doNotReturnIfExists)
+	if err != nil {
+		return err
+	}
+
+	defer func(e *error) { *e = errors.Join(*e, writer.Close()) }(&err)
+
+	redactor, err := backupformat.NewRedactor(decoder, writer, backupformat.RedactionOptions{
+		RedactDefinitions: cobrautil.MustGetBool(cmd, "redact-definitions"),
+		RedactRelations:   cobrautil.MustGetBool(cmd, "redact-relations"),
+		RedactObjectIDs:   cobrautil.MustGetBool(cmd, "redact-object-ids"),
+	})
+	if err != nil {
+		return fmt.Errorf("error creating redactor: %w", err)
+	}
+
+	defer func(e *error) { *e = errors.Join(*e, redactor.Close()) }(&err)
+	bar := console.CreateProgressBar("redacting backup")
+	var written int64
+	// Pump the redactor until EOF, bailing out if the command context is
+	// canceled between items.
+	for {
+		if err := cmd.Context().Err(); err != nil {
+			return fmt.Errorf("aborted redaction: %w", err)
+		}
+
+		err := redactor.Next()
+		if err != nil {
+			if errors.Is(err, io.EOF) {
+				break
+			}
+			return fmt.Errorf("error redacting: %w", err)
+		}
+
+		written++
+		if err := bar.Set64(written); err != nil {
+			return fmt.Errorf("error incrementing progress bar: %w", err)
+		}
+	}
+
+	if err := bar.Finish(); err != nil {
+		return fmt.Errorf("error finalizing progress bar: %w", err)
+	}
+
+	fmt.Println("Redaction map:")
+	fmt.Println("--------------")
+	fmt.Println()
+
+	// Draw a table of definitions, caveats and relations mapped.
+	tbl := table.New("Definition Name", "Redacted Name")
+	for k, v := range redactor.RedactionMap().Definitions {
+		tbl.AddRow(k, v)
+	}
+
+	tbl.Print()
+	fmt.Println()
+
+	if len(redactor.RedactionMap().Caveats) > 0 {
+		tbl = table.New("Caveat Name", "Redacted Name")
+		for k, v := range redactor.RedactionMap().Caveats {
+			tbl.AddRow(k, v)
+		}
+		tbl.Print()
+		fmt.Println()
+	}
+
+	tbl = table.New("Relation/Permission Name", "Redacted Name")
+	for k, v := range redactor.RedactionMap().Relations {
+		tbl.AddRow(k, v)
+	}
+	tbl.Print()
+	fmt.Println()
+
+	// Object IDs can be numerous, so only print them when explicitly requested.
+	if len(redactor.RedactionMap().ObjectIDs) > 0 && cobrautil.MustGetBool(cmd, "print-redacted-object-ids") {
+		tbl = table.New("Object ID", "Redacted Object ID")
+		for k, v := range redactor.RedactionMap().ObjectIDs {
+			tbl.AddRow(k, v)
+		}
+		tbl.Print()
+		fmt.Println()
+	}
+
+	return nil
+}
+
+// backupParseRelsCmdFunc prints the relationships stored in a backup file to
+// out, one per line, skipping any that do not match the --prefix-filter flag.
+func backupParseRelsCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error {
+	prefix := cobrautil.MustGetString(cmd, "prefix-filter")
+	decoder, closer, err := decoderFromArgs(args...)
+	if err != nil {
+		return err
+	}
+
+	// Join close errors into the returned error so they are not lost.
+	defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err)
+	defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err)
+
+	// NOTE(review): a non-nil, non-EOF error from decoder.Next() ends this
+	// loop without being returned — confirm whether it should be surfaced.
+	for rel, err := decoder.Next(); rel != nil && err == nil; rel, err = decoder.Next() {
+		if !hasRelPrefix(rel, prefix) {
+			continue
+		}
+
+		relString, err := tuple.V1StringRelationship(rel)
+		if err != nil {
+			return err
+		}
+
+		if _, err = fmt.Fprintln(out, replaceRelString(relString)); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// decoderFromArgs builds a backup decoder from an optional filename argument.
+// With no argument the backup is read from stdin. The returned io.Closer is
+// the underlying file and must be closed by the caller along with the decoder.
+func decoderFromArgs(args ...string) (*backupformat.Decoder, io.Closer, error) {
+	filename := "" // Default to stdin.
+	if len(args) > 0 {
+		filename = args[0]
+	}
+
+	f, _, err := openRestoreFile(filename)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	decoder, err := backupformat.NewDecoder(f)
+	if err != nil {
+		return nil, nil, fmt.Errorf("error creating restore file decoder: %w", err)
+	}
+
+	return decoder, f, nil
+}
+
+// replaceRelString reformats a relationship string for display by replacing
+// the first '@' and the first '#' with spaces.
+func replaceRelString(rel string) string {
+	rel = strings.Replace(rel, "@", " ", 1)
+	return strings.Replace(rel, "#", " ", 1)
+}
+
+// rewriteLegacy strips constructs from old Serverless permission systems that
+// are no longer valid, replacing each match of the missingAllowedTypes and
+// shortRelations patterns with an explanatory comment.
+func rewriteLegacy(schema string) string {
+	schema = string(missingAllowedTypes.ReplaceAll([]byte(schema), []byte("\n/* deleted missing allowed type error */")))
+	return string(shortRelations.ReplaceAll([]byte(schema), []byte("\n/* deleted short relation name */")))
+}
+
+// sizeErrorRegEx matches gRPC "message larger than max" errors and captures
+// the actual (1) and allowed (2) byte counts.
+var sizeErrorRegEx = regexp.MustCompile(`received message larger than max \((\d+) vs. (\d+)\)`)
+
+// addSizeErrInfo augments gRPC ResourceExhausted errors caused by oversized
+// messages with a hint about the --max-message-size flag. When the required
+// size can be parsed from the error, the hint suggests twice that size;
+// otherwise a generic hint is added. All other errors pass through unchanged.
+func addSizeErrInfo(err error) error {
+	if err == nil {
+		return nil
+	}
+
+	code := status.Code(err)
+	if code != codes.ResourceExhausted {
+		return err
+	}
+
+	if !strings.Contains(err.Error(), "received message larger than max") {
+		return err
+	}
+
+	matches := sizeErrorRegEx.FindStringSubmatch(err.Error())
+	if len(matches) != 3 {
+		return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err)
+	}
+
+	necessaryByteCount, atoiErr := strconv.Atoi(matches[1])
+	if atoiErr != nil {
+		return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err)
+	}
+
+	// Suggest double the observed message size to leave headroom.
+	return fmt.Errorf("%w: set flag --max-message-size=%d to increase the maximum allowable size", err, 2*necessaryByteCount)
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/cmd.go b/vendor/github.com/authzed/zed/internal/cmd/cmd.go
new file mode 100644
index 0000000..a2e5332
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/cmd.go
@@ -0,0 +1,187 @@
+package cmd
+
+import (
+ "context"
+ "errors"
+ "io"
+ "os"
+ "os/signal"
+ "strings"
+ "syscall"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/jzelinskie/cobrautil/v2/cobrazerolog"
+ "github.com/mattn/go-isatty"
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+
+ "github.com/authzed/zed/internal/commands"
+)
+
+var (
+ SyncFlagsCmdFunc = cobrautil.SyncViperPreRunE("ZED")
+ errParsing = errors.New("parsing error")
+)
+
+func init() {
+	// NOTE: this is mostly to set up logging in the case where
+	// the command doesn't exist or the construction of the command
+	// errors out before the PersistentPreRunE setup in the below function.
+	// It helps keep log output visually consistent for a user even in
+	// exceptional cases.
+	var output io.Writer
+
+	// Pretty console output when stdout is a TTY; raw JSON-style zerolog
+	// output (on stderr) otherwise.
+	if isatty.IsTerminal(os.Stdout.Fd()) {
+		output = zerolog.ConsoleWriter{Out: os.Stderr}
+	} else {
+		output = os.Stderr
+	}
+
+	l := zerolog.New(output).With().Timestamp().Logger()
+
+	log.Logger = l
+}
+
+// flagError is indirected through a variable so it can be substituted.
+var flagError = flagErrorFunc
+
+// flagErrorFunc prints the flag-parsing error and the command's usage, then
+// returns the errParsing sentinel so callers can suppress duplicate logging.
+func flagErrorFunc(cmd *cobra.Command, err error) error {
+	cmd.Println(err)
+	cmd.Println(cmd.UsageString())
+	return errParsing
+}
+
+// InitialiseRootCmd constructs the zed root command, registering persistent
+// flags, the version command, root-level aliases, CLI-only commands, and the
+// commands shared with other frontends. It is also used to generate docs for
+// zed.
+func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command {
+	rootCmd := &cobra.Command{
+		Use:   "zed",
+		Short: "SpiceDB CLI, built by AuthZed",
+		Long:  "A command-line client for managing SpiceDB clusters.",
+		PersistentPreRunE: cobrautil.CommandStack(
+			zl.RunE(),
+			SyncFlagsCmdFunc,
+			commands.InjectRequestID,
+		),
+		SilenceErrors: true,
+		SilenceUsage:  true,
+	}
+	rootCmd.SetFlagErrorFunc(func(command *cobra.Command, err error) error {
+		return flagError(command, err)
+	})
+
+	zl.RegisterFlags(rootCmd.PersistentFlags())
+
+	rootCmd.PersistentFlags().String("endpoint", "", "spicedb gRPC API endpoint")
+	rootCmd.PersistentFlags().String("permissions-system", "", "permissions system to query")
+	rootCmd.PersistentFlags().String("hostname-override", "", "override the hostname used in the connection to the endpoint")
+	rootCmd.PersistentFlags().String("token", "", "token used to authenticate to SpiceDB")
+	rootCmd.PersistentFlags().String("certificate-path", "", "path to certificate authority used to verify secure connections")
+	rootCmd.PersistentFlags().Bool("insecure", false, "connect over a plaintext connection")
+	rootCmd.PersistentFlags().Bool("skip-version-check", false, "if true, no version check is performed against the server")
+	rootCmd.PersistentFlags().Bool("no-verify-ca", false, "do not attempt to verify the server's certificate chain and host name")
+	rootCmd.PersistentFlags().Bool("debug", false, "enable debug logging")
+	rootCmd.PersistentFlags().String("request-id", "", "optional id to send along with SpiceDB requests for tracing")
+	rootCmd.PersistentFlags().Int("max-message-size", 0, "maximum size *in bytes* (defaults to 4_194_304 bytes ~= 4MB) of a gRPC message that can be sent or received by zed")
+	rootCmd.PersistentFlags().String("proxy", "", "specify a SOCKS5 proxy address")
+	rootCmd.PersistentFlags().Uint("max-retries", 10, "maximum number of sequential retries to attempt when a request fails")
+	_ = rootCmd.PersistentFlags().MarkHidden("debug") // This cannot return its error.
+
+	versionCmd := &cobra.Command{
+		Use:   "version",
+		Short: "Display zed and SpiceDB version information",
+		RunE:  versionCmdFunc,
+	}
+	cobrautil.RegisterVersionFlags(versionCmd.Flags())
+	versionCmd.Flags().Bool("include-remote-version", true, "whether to display the version of Authzed or SpiceDB for the current context")
+	rootCmd.AddCommand(versionCmd)
+
+	// Register root-level aliases
+	rootCmd.AddCommand(&cobra.Command{
+		Use:               "use <context>",
+		Short:             "Alias for `zed context use`",
+		Args:              commands.ValidationWrapper(cobra.MaximumNArgs(1)),
+		RunE:              contextUseCmdFunc,
+		ValidArgsFunction: ContextGet,
+	})
+
+	// Register CLI-only commands.
+	registerContextCmd(rootCmd)
+	registerImportCmd(rootCmd)
+	registerValidateCmd(rootCmd)
+	registerBackupCmd(rootCmd)
+	registerPreviewCmd(rootCmd)
+
+	// Register shared commands.
+	commands.RegisterPermissionCmd(rootCmd)
+
+	relCmd := commands.RegisterRelationshipCmd(rootCmd)
+
+	commands.RegisterWatchCmd(rootCmd)
+	commands.RegisterWatchRelationshipCmd(relCmd)
+
+	schemaCmd := commands.RegisterSchemaCmd(rootCmd)
+	registerAdditionalSchemaCmds(schemaCmd)
+
+	return rootCmd
+}
+
+// Run executes the zed CLI and exits the process with status 1 on error.
+func Run() {
+	if err := runWithoutExit(); err != nil {
+		os.Exit(1)
+	}
+}
+
+// runWithoutExit builds the root command and executes it with a context that
+// is canceled on SIGINT/SIGTERM, returning the (already handled) error rather
+// than exiting, so callers control the process exit code.
+func runWithoutExit() error {
+	zl := cobrazerolog.New(cobrazerolog.WithPreRunLevel(zerolog.DebugLevel))
+
+	rootCmd := InitialiseRootCmd(zl)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	// Buffered so a second signal is not dropped while the first is handled.
+	signalChan := make(chan os.Signal, 2)
+	signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
+	defer func() {
+		signal.Stop(signalChan)
+		cancel()
+	}()
+
+	// Cancel the command context on the first signal; exit the goroutine when
+	// the context ends for any other reason.
+	go func() {
+		select {
+		case <-signalChan:
+			cancel()
+		case <-ctx.Done():
+		}
+	}()
+
+	return handleError(rootCmd, rootCmd.ExecuteContext(ctx))
+}
+
+// handleError reports an execution error appropriately for its kind: usage is
+// printed for validation and unknown-command errors, parsing errors (already
+// printed by flagError) are suppressed, and anything else is logged. The
+// original error is always returned so the caller can set the exit status.
+func handleError(command *cobra.Command, err error) error {
+	if err == nil {
+		return nil
+	}
+	// this snippet of code is taken from Command.ExecuteC in order to determine the command that was ultimately
+	// parsed. This is necessary to be able to print the proper command-specific usage
+	var findErr error
+	var cmdToExecute *cobra.Command
+	args := os.Args[1:]
+	if command.TraverseChildren {
+		cmdToExecute, _, findErr = command.Traverse(args)
+	} else {
+		cmdToExecute, _, findErr = command.Find(args)
+	}
+	if findErr != nil {
+		// Fall back to the root command's usage when lookup fails.
+		cmdToExecute = command
+	}
+
+	if errors.Is(err, commands.ValidationError{}) {
+		_ = flagError(cmdToExecute, err)
+	} else if err != nil && strings.Contains(err.Error(), "unknown command") {
+		_ = flagError(cmdToExecute, err)
+	} else if !errors.Is(err, errParsing) {
+		log.Err(err).Msg("terminated with errors")
+	}
+
+	return err
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/context.go b/vendor/github.com/authzed/zed/internal/cmd/context.go
new file mode 100644
index 0000000..d5a0eec
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/context.go
@@ -0,0 +1,202 @@
+package cmd
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/jzelinskie/stringz"
+ "github.com/spf13/cobra"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/console"
+ "github.com/authzed/zed/internal/printers"
+ "github.com/authzed/zed/internal/storage"
+)
+
+// registerContextCmd attaches the `zed context` command and its subcommands
+// (list, set, remove, use) to the root command.
+func registerContextCmd(rootCmd *cobra.Command) {
+	rootCmd.AddCommand(contextCmd)
+
+	contextCmd.AddCommand(contextListCmd)
+	contextListCmd.Flags().Bool("reveal-tokens", false, "display secrets in results")
+
+	contextCmd.AddCommand(contextSetCmd)
+	contextCmd.AddCommand(contextRemoveCmd)
+	contextCmd.AddCommand(contextUseCmd)
+}
+
+// contextCmd is the parent `zed context` command grouping context management.
+var contextCmd = &cobra.Command{
+	Use:     "context <subcommand>",
+	Short:   "Manage configurations for connecting to SpiceDB deployments",
+	Aliases: []string{"ctx"},
+}
+
+// contextListCmd implements `zed context list`.
+var contextListCmd = &cobra.Command{
+	Use:               "list",
+	Short:             "Lists all available contexts",
+	Aliases:           []string{"ls"},
+	Args:              commands.ValidationWrapper(cobra.ExactArgs(0)),
+	ValidArgsFunction: cobra.NoFileCompletions,
+	RunE:              contextListCmdFunc,
+}
+
+// contextSetCmd implements `zed context set`.
+var contextSetCmd = &cobra.Command{
+	Use:               "set <name> <endpoint> <api-token>",
+	Short:             "Creates or overwrite a context",
+	Args:              commands.ValidationWrapper(cobra.ExactArgs(3)),
+	ValidArgsFunction: cobra.NoFileCompletions,
+	RunE:              contextSetCmdFunc,
+}
+
+// contextRemoveCmd implements `zed context remove`.
+var contextRemoveCmd = &cobra.Command{
+	Use:               "remove <system>",
+	Short:             "Removes a context",
+	Aliases:           []string{"rm"},
+	Args:              commands.ValidationWrapper(cobra.ExactArgs(1)),
+	ValidArgsFunction: ContextGet,
+	RunE:              contextRemoveCmdFunc,
+}
+
+// contextUseCmd implements `zed context use`; with no argument it prints the
+// current context instead of switching.
+var contextUseCmd = &cobra.Command{
+	Use:               "use <system>",
+	Short:             "Sets a context as the current context",
+	Args:              commands.ValidationWrapper(cobra.MaximumNArgs(1)),
+	ValidArgsFunction: ContextGet,
+	RunE:              contextUseCmdFunc,
+}
+
+// ContextGet is a cobra shell-completion function returning the names of all
+// stored contexts (tokens) from the default secret store.
+func ContextGet(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
+	_, secretStore := client.DefaultStorage()
+	secrets, err := secretStore.Get()
+	if err != nil {
+		return nil, cobra.ShellCompDirectiveError
+	}
+
+	names := make([]string, 0, len(secrets.Tokens))
+	for _, token := range secrets.Tokens {
+		names = append(names, token.Name)
+	}
+
+	return names, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveNoSpace | cobra.ShellCompDirectiveKeepOrder
+}
+
+// contextListCmdFunc prints a table of all stored contexts, marking the
+// current one and redacting API tokens unless --reveal-tokens is set.
+func contextListCmdFunc(cmd *cobra.Command, _ []string) error {
+	cfgStore, secretStore := client.DefaultStorage()
+	secrets, err := secretStore.Get()
+	if err != nil {
+		return err
+	}
+
+	cfg, err := cfgStore.Get()
+	if err != nil {
+		return err
+	}
+
+	rows := make([][]string, 0, len(secrets.Tokens))
+	for _, token := range secrets.Tokens {
+		current := ""
+		if token.Name == cfg.CurrentToken {
+			current = " ✓ "
+		}
+		secret := token.APIToken
+		if !cobrautil.MustGetBool(cmd, "reveal-tokens") {
+			secret = token.Redacted()
+		}
+
+		// Summarize how TLS is configured for this context.
+		var certStr string
+		if token.IsInsecure() {
+			certStr = "insecure"
+		} else if token.HasNoVerifyCA() {
+			certStr = "no-verify-ca"
+		} else if _, ok := token.Certificate(); ok {
+			certStr = "custom"
+		} else {
+			certStr = "system"
+		}
+
+		rows = append(rows, []string{
+			current,
+			token.Name,
+			token.Endpoint,
+			secret,
+			certStr,
+		})
+	}
+
+	printers.PrintTable(os.Stdout, []string{"current", "name", "endpoint", "token", "tls cert"}, rows)
+
+	return nil
+}
+
+// contextSetCmdFunc creates or overwrites a context from the
+// <name> <endpoint> <api-token> arguments, storing any CA certificate and
+// TLS flags alongside the token, and makes it the current context.
+// An empty endpoint defaults to grpc.authzed.com:443.
+func contextSetCmdFunc(cmd *cobra.Command, args []string) error {
+	var name, endpoint, apiToken string
+	err := stringz.Unpack(args, &name, &endpoint, &apiToken)
+	if err != nil {
+		return err
+	}
+
+	// Optionally load a CA certificate to pin for this context.
+	certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path")
+	var certBytes []byte
+	if certPath != "" {
+		certBytes, err = os.ReadFile(certPath)
+		if err != nil {
+			return fmt.Errorf("failed to read certificate: %w", err)
+		}
+	}
+
+	insecure := cobrautil.MustGetBool(cmd, "insecure")
+	noVerifyCA := cobrautil.MustGetBool(cmd, "no-verify-ca")
+	cfgStore, secretStore := client.DefaultStorage()
+	err = storage.PutToken(storage.Token{
+		Name:       name,
+		Endpoint:   stringz.DefaultEmpty(endpoint, "grpc.authzed.com:443"),
+		APIToken:   apiToken,
+		Insecure:   &insecure,
+		NoVerifyCA: &noVerifyCA,
+		CACert:     certBytes,
+	}, secretStore)
+	if err != nil {
+		return err
+	}
+
+	return storage.SetCurrentToken(name, cfgStore, secretStore)
+}
+
+// contextRemoveCmdFunc deletes the context named in args[0], first clearing
+// it from the config if it is the currently selected context.
+func contextRemoveCmdFunc(_ *cobra.Command, args []string) error {
+	// If the token is what's currently being used, remove it from the config.
+	cfgStore, secretStore := client.DefaultStorage()
+	cfg, err := cfgStore.Get()
+	if err != nil {
+		return err
+	}
+
+	if cfg.CurrentToken == args[0] {
+		cfg.CurrentToken = ""
+	}
+
+	err = cfgStore.Put(cfg)
+	if err != nil {
+		return err
+	}
+
+	return storage.RemoveToken(args[0], secretStore)
+}
+
+// contextUseCmdFunc switches the current context to args[0]; with no argument
+// it prints the name of the current context instead.
+func contextUseCmdFunc(_ *cobra.Command, args []string) error {
+	cfgStore, secretStore := client.DefaultStorage()
+	switch len(args) {
+	case 0:
+		cfg, err := cfgStore.Get()
+		if err != nil {
+			return err
+		}
+		console.Println(cfg.CurrentToken)
+	case 1:
+		return storage.SetCurrentToken(args[0], cfgStore, secretStore)
+	default:
+		// Unreachable: Args enforces MaximumNArgs(1).
+		panic("cobra command did not enforce valid number of args")
+	}
+
+	return nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/import.go b/vendor/github.com/authzed/zed/internal/cmd/import.go
new file mode 100644
index 0000000..687d42e
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/import.go
@@ -0,0 +1,181 @@
+package cmd
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "net/url"
+ "strings"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+ "github.com/authzed/spicedb/pkg/validationfile"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/decode"
+ "github.com/authzed/zed/internal/grpcutil"
+)
+
+// registerImportCmd attaches the `zed import` command and its flags to the
+// root command.
+func registerImportCmd(rootCmd *cobra.Command) {
+	rootCmd.AddCommand(importCmd)
+	importCmd.Flags().Int("batch-size", 1000, "import batch size")
+	importCmd.Flags().Int("workers", 1, "number of concurrent batching workers")
+	importCmd.Flags().Bool("schema", true, "import schema")
+	importCmd.Flags().Bool("relationships", true, "import relationships")
+	importCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before importing")
+}
+
+// importCmd implements `zed import`, loading schema and relationships from a
+// local file or a supported URL (gist, playground, pastebin, devtools).
+var importCmd = &cobra.Command{
+	Use:   "import <url>",
+	Short: "Imports schema and relationships from a file or url",
+	Example: `
+	From a gist:
+		zed import https://gist.github.com/ecordell/8e3b613a677e3c844742cf24421c08b6
+
+	From a playground link:
+		zed import https://play.authzed.com/s/iksdFvCtvnkR/schema
+
+	From pastebin:
+		zed import https://pastebin.com/8qU45rVK
+
+	From a devtools instance:
+		zed import https://localhost:8443/download
+
+	From a local file (with prefix):
+		zed import file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
+
+	From a local file (no prefix):
+		zed import authzed-x7izWU8_2Gw3.yaml
+
+	Only schema:
+		zed import --relationships=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
+
+	Only relationships:
+		zed import --schema=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
+
+	With schema definition prefix:
+		zed import --schema-definition-prefix=mypermsystem file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
+`,
+	Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
+	RunE: importCmdFunc,
+}
+
+// importCmdFunc decodes a validation file from the URL/path in args[0] and
+// imports its schema and/or relationships (per the --schema/--relationships
+// flags) into the configured SpiceDB, applying any definition prefix.
+func importCmdFunc(cmd *cobra.Command, args []string) error {
+	// NOTE: the local variable shadows the imported client package below.
+	client, err := client.NewClient(cmd)
+	if err != nil {
+		return err
+	}
+	u, err := url.Parse(args[0])
+	if err != nil {
+		return err
+	}
+
+	decoder, err := decode.DecoderForURL(u)
+	if err != nil {
+		return err
+	}
+	var p validationfile.ValidationFile
+	if _, _, err := decoder(&p); err != nil {
+		return err
+	}
+
+	prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), client, nil)
+	if err != nil {
+		return err
+	}
+
+	if cobrautil.MustGetBool(cmd, "schema") {
+		if err := importSchema(cmd.Context(), client, p.Schema.Schema, prefix); err != nil {
+			return err
+		}
+	}
+
+	if cobrautil.MustGetBool(cmd, "relationships") {
+		batchSize := cobrautil.MustGetInt(cmd, "batch-size")
+		workers := cobrautil.MustGetInt(cmd, "workers")
+		if err := importRelationships(cmd.Context(), client, p.Relationships.RelationshipsString, prefix, batchSize, workers); err != nil {
+			return err
+		}
+	}
+
+	return err
+}
+
+// importSchema rewrites the given schema with the optional definition prefix
+// and writes the result to SpiceDB via WriteSchema.
+func importSchema(ctx context.Context, client client.Client, schema string, definitionPrefix string) error {
+	log.Info().Msg("importing schema")
+
+	// Recompile the schema with the specified prefix.
+	schemaText, err := rewriteSchema(schema, definitionPrefix)
+	if err != nil {
+		return err
+	}
+
+	// Write the recompiled and regenerated schema.
+	request := &v1.WriteSchemaRequest{Schema: schemaText}
+	log.Trace().Interface("request", request).Str("schema", schemaText).Msg("writing schema")
+
+	if _, err := client.WriteSchema(ctx, request); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// importRelationships parses one relationship per non-empty, non-comment line
+// of relationships, optionally prepends definitionPrefix to the resource and
+// subject object types, and writes the resulting TOUCH updates to SpiceDB in
+// concurrent batches of batchSize using the given number of workers.
+func importRelationships(ctx context.Context, client client.Client, relationships string, definitionPrefix string, batchSize int, workers int) error {
+	relationshipUpdates := make([]*v1.RelationshipUpdate, 0)
+	scanner := bufio.NewScanner(strings.NewReader(relationships))
+	for scanner.Scan() {
+		line := strings.TrimSpace(scanner.Text())
+		if line == "" {
+			continue
+		}
+		// Skip comment lines.
+		if strings.HasPrefix(line, "//") {
+			continue
+		}
+		rel, err := tuple.ParseV1Rel(line)
+		if err != nil {
+			// Wrap the parse error so the underlying cause is not discarded.
+			return fmt.Errorf("failed to parse %s as relationship: %w", line, err)
+		}
+		log.Trace().Str("line", line).Send()
+
+		// Rewrite the prefix on the references, if any.
+		if len(definitionPrefix) > 0 {
+			rel.Resource.ObjectType = fmt.Sprintf("%s/%s", definitionPrefix, rel.Resource.ObjectType)
+			rel.Subject.Object.ObjectType = fmt.Sprintf("%s/%s", definitionPrefix, rel.Subject.Object.ObjectType)
+		}
+
+		relationshipUpdates = append(relationshipUpdates, &v1.RelationshipUpdate{
+			Operation:    v1.RelationshipUpdate_OPERATION_TOUCH,
+			Relationship: rel,
+		})
+	}
+	if err := scanner.Err(); err != nil {
+		return err
+	}
+
+	log.Info().
+		Int("batch_size", batchSize).
+		Int("workers", workers).
+		Int("count", len(relationshipUpdates)).
+		Msg("importing relationships")
+
+	err := grpcutil.ConcurrentBatch(ctx, len(relationshipUpdates), batchSize, workers, func(ctx context.Context, no int, start int, end int) error {
+		request := &v1.WriteRelationshipsRequest{Updates: relationshipUpdates[start:end]}
+		_, err := client.WriteRelationships(ctx, request)
+		if err != nil {
+			return err
+		}
+
+		log.Info().
+			Int("batch_no", no).
+			Int("count", len(relationshipUpdates[start:end])).
+			Msg("imported relationships")
+		return nil
+	})
+	return err
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/preview.go b/vendor/github.com/authzed/zed/internal/cmd/preview.go
new file mode 100644
index 0000000..279e310
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/preview.go
@@ -0,0 +1,120 @@
+package cmd
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+ "golang.org/x/term"
+
+ newcompiler "github.com/authzed/spicedb/pkg/composableschemadsl/compiler"
+ newinput "github.com/authzed/spicedb/pkg/composableschemadsl/input"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/schemadsl/generator"
+
+ "github.com/authzed/zed/internal/commands"
+)
+
+// registerPreviewCmd attaches the `zed preview` command tree
+// (preview > schema > compile) to the root command.
+func registerPreviewCmd(rootCmd *cobra.Command) {
+	rootCmd.AddCommand(previewCmd)
+
+	previewCmd.AddCommand(schemaCmd)
+
+	schemaCmd.AddCommand(schemaCompileCmd)
+	schemaCompileCmd.Flags().String("out", "", "output filepath; omitting writes to stdout")
+}
+
+// previewCmd groups experimental commands exposed for preview.
+var previewCmd = &cobra.Command{
+	Use:   "preview <subcommand>",
+	Short: "Experimental commands that have been made available for preview",
+}
+
+// schemaCmd groups preview schema-related subcommands.
+var schemaCmd = &cobra.Command{
+	Use:   "schema <subcommand>",
+	Short: "Manage schema for a permissions system",
+}
+
+// schemaCompileCmd implements `zed preview schema compile`.
+var schemaCompileCmd = &cobra.Command{
+	Use:   "compile <file>",
+	Args:  commands.ValidationWrapper(cobra.ExactArgs(1)),
+	Short: "Compile a schema that uses extended syntax into one that can be written to SpiceDB",
+	Example: `
+	Write to stdout:
+		zed preview schema compile root.zed
+	Write to an output file:
+		zed preview schema compile root.zed --out compiled.zed
+	`,
+	ValidArgsFunction: commands.FileExtensionCompletions("zed"),
+	RunE:              schemaCompileCmdFunc,
+}
+
+// schemaCompileCmdFunc compiles an input schema written in the new composable
+// schema syntax and produces it as a fully-realized schema, written either to
+// the --out file or to stdout (stdout must be a terminal when --out is empty).
+func schemaCompileCmdFunc(cmd *cobra.Command, args []string) error {
+	stdOutFd, err := safecast.ToInt(uint(os.Stdout.Fd()))
+	if err != nil {
+		return err
+	}
+	outputFilepath := cobrautil.MustGetString(cmd, "out")
+	if outputFilepath == "" && !term.IsTerminal(stdOutFd) {
+		return fmt.Errorf("must provide stdout or output file path")
+	}
+
+	inputFilepath := args[0]
+	// The source folder lets the compiler resolve relative imports.
+	inputSourceFolder := filepath.Dir(inputFilepath)
+	var schemaBytes []byte
+	schemaBytes, err = os.ReadFile(inputFilepath)
+	if err != nil {
+		return fmt.Errorf("failed to read schema file: %w", err)
+	}
+	log.Trace().Str("schema", string(schemaBytes)).Str("file", args[0]).Msg("read schema from file")
+
+	if len(schemaBytes) == 0 {
+		return errors.New("attempted to compile empty schema")
+	}
+
+	compiled, err := newcompiler.Compile(newcompiler.InputSchema{
+		Source:       newinput.Source(inputFilepath),
+		SchemaString: string(schemaBytes),
+	}, newcompiler.AllowUnprefixedObjectType(),
+		newcompiler.SourceFolder(inputSourceFolder))
+	if err != nil {
+		return err
+	}
+
+	// Attempt to cast one kind of OrderedDefinition to another
+	oldDefinitions := make([]compiler.SchemaDefinition, 0, len(compiled.OrderedDefinitions))
+	for _, definition := range compiled.OrderedDefinitions {
+		oldDefinition, ok := definition.(compiler.SchemaDefinition)
+		if !ok {
+			// Report the definition that failed the assertion, not the zero-
+			// valued target of the failed cast.
+			return fmt.Errorf("could not convert definition to old schemadefinition: %v", definition)
+		}
+		oldDefinitions = append(oldDefinitions, oldDefinition)
+	}
+
+	// This is where we functionally assert that the two systems are compatible
+	generated, _, err := generator.GenerateSchema(oldDefinitions)
+	if err != nil {
+		return fmt.Errorf("could not generate resulting schema: %w", err)
+	}
+
+	// Add a newline at the end for hygiene's sake
+	terminated := generated + "\n"
+
+	if outputFilepath == "" {
+		// Print to stdout
+		fmt.Print(terminated)
+	} else {
+		err = os.WriteFile(outputFilepath, []byte(terminated), 0o_600)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/restorer.go b/vendor/github.com/authzed/zed/internal/cmd/restorer.go
new file mode 100644
index 0000000..468064f
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/restorer.go
@@ -0,0 +1,446 @@
+package cmd
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/cenkalti/backoff/v4"
+ "github.com/mattn/go-isatty"
+ "github.com/rs/zerolog/log"
+ "github.com/samber/lo"
+ "github.com/schollz/progressbar/v3"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/console"
+ "github.com/authzed/zed/pkg/backupformat"
+)
+
+// ConflictStrategy configures how the restorer reacts when a restored
+// relationship already exists in the destination.
+type ConflictStrategy int
+
+const (
+ // Fail aborts the restore on the first duplicate relationship.
+ Fail ConflictStrategy = iota
+ // Skip drops conflicting batches and continues restoring.
+ Skip
+ // Touch re-writes conflicting batches with TOUCH semantics.
+ Touch
+
+ // defaultBackoff is the initial retry delay; defaultMaxRetries caps the
+ // per-batch retry attempts in writeBatchesWithRetry.
+ defaultBackoff = 50 * time.Millisecond
+ defaultMaxRetries = 10
+)
+
+// conflictStrategyMapping maps the user-facing flag values to strategies.
+var conflictStrategyMapping = map[string]ConflictStrategy{
+ "fail": Fail,
+ "skip": Skip,
+ "touch": Touch,
+}
+
+// Fallback for datastore implementations on SpiceDB < 1.29.0 not returning proper gRPC codes
+// Remove once https://github.com/authzed/spicedb/pull/1688 lands
+var (
+ // txConflictCodes are datastore-specific error substrings indicating a
+ // unique-constraint violation, i.e. a duplicate relationship.
+ txConflictCodes = []string{
+ "SQLSTATE 23505", // CockroachDB
+ "Error 1062 (23000)", // MySQL
+ }
+ // retryableErrorCodes are error substrings indicating a transient failure
+ // where the write can safely be attempted again.
+ retryableErrorCodes = []string{
+ "retryable error", // CockroachDB, PostgreSQL
+ "try restarting transaction", "Error 1205", // MySQL
+ }
+)
+
+// restorer streams relationships from a backup decoder into a SpiceDB
+// instance, batching writes and applying the configured conflict strategy.
+type restorer struct {
+ schema string
+ decoder *backupformat.Decoder
+ client client.Client
+ prefixFilter string
+ batchSize uint
+ batchesPerTransaction uint
+ conflictStrategy ConflictStrategy
+ disableRetryErrors bool
+ bar *progressbar.ProgressBar
+
+ // stats accumulated during the restore and reported when it finishes
+ filteredOutRels uint
+ writtenRels uint
+ writtenBatches uint
+ skippedRels uint
+ skippedBatches uint
+ duplicateRels uint
+ duplicateBatches uint
+ totalRetries uint
+ requestTimeout time.Duration
+}
+
+// newRestorer constructs a restorer over the given decoder/client pair and
+// initializes its progress bar; the remaining parameters tune batching,
+// conflict handling, retry behavior, and per-request timeouts.
+func newRestorer(schema string, decoder *backupformat.Decoder, client client.Client, prefixFilter string, batchSize uint,
+ batchesPerTransaction uint, conflictStrategy ConflictStrategy, disableRetryErrors bool,
+ requestTimeout time.Duration,
+) *restorer {
+ return &restorer{
+ decoder: decoder,
+ schema: schema,
+ client: client,
+ prefixFilter: prefixFilter,
+ requestTimeout: requestTimeout,
+ batchSize: batchSize,
+ batchesPerTransaction: batchesPerTransaction,
+ conflictStrategy: conflictStrategy,
+ disableRetryErrors: disableRetryErrors,
+ bar: console.CreateProgressBar("restoring from backup"),
+ }
+}
+
+// restoreFromDecoder writes the backup's schema to the destination and then
+// streams every relationship from the decoder into it, sending batches of
+// batchSize relationships and committing every batchesPerTransaction batches
+// via commitStream. Progress and final stats are reported through the
+// progress bar and the logger.
+func (r *restorer) restoreFromDecoder(ctx context.Context) error {
+ relationshipWriteStart := time.Now()
+ defer func() {
+ if err := r.bar.Finish(); err != nil {
+ log.Warn().Err(err).Msg("error finalizing progress bar")
+ }
+ }()
+
+ // Restore the schema first so relationship writes validate against it.
+ r.bar.Describe("restoring schema from backup")
+ if _, err := r.client.WriteSchema(ctx, &v1.WriteSchemaRequest{
+ Schema: r.schema,
+ }); err != nil {
+ return fmt.Errorf("unable to write schema: %w", err)
+ }
+
+ relationshipWriter, err := r.client.BulkImportRelationships(ctx)
+ if err != nil {
+ return fmt.Errorf("error creating writer stream: %w", err)
+ }
+
+ r.bar.Describe("restoring relationships from backup")
+ batch := make([]*v1.Relationship, 0, r.batchSize)
+ batchesToBeCommitted := make([][]*v1.Relationship, 0, r.batchesPerTransaction)
+ // NOTE(review): an error returned by decoder.Next() terminates this loop
+ // without being surfaced to the caller — confirm that a truncated or
+ // corrupt backup should not fail the restore.
+ for rel, err := r.decoder.Next(); rel != nil && err == nil; rel, err = r.decoder.Next() {
+ if err := ctx.Err(); err != nil {
+ r.bar.Describe("backup restore aborted")
+ return fmt.Errorf("aborted restore: %w", err)
+ }
+
+ // Relationships outside the requested prefix are counted but not written.
+ if !hasRelPrefix(rel, r.prefixFilter) {
+ r.filteredOutRels++
+ continue
+ }
+
+ batch = append(batch, rel)
+
+ // Batch is full: ship it on the stream and queue it for commit/retry.
+ if uint(len(batch))%r.batchSize == 0 {
+ batchesToBeCommitted = append(batchesToBeCommitted, batch)
+ err := relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{
+ Relationships: batch,
+ })
+ if err != nil {
+ // It feels non-idiomatic to check for error and perform an operation, but in gRPC, when an element
+ // sent over the stream fails, we need to call recvAndClose() to get the error.
+ if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil {
+ return fmt.Errorf("error committing batches: %w", err)
+ }
+
+ // The failed Send closed the stream; open a fresh one before continuing.
+ relationshipWriter, err = r.client.BulkImportRelationships(ctx)
+ if err != nil {
+ return fmt.Errorf("error creating new writer stream: %w", err)
+ }
+
+ batchesToBeCommitted = batchesToBeCommitted[:0]
+ batch = batch[:0]
+ continue
+ }
+
+ // The batch just sent is kept in batchesToBeCommitted, which is used for retries.
+ // Therefore, we cannot reuse the batch. Batches may fail on send, or on commit (CloseAndRecv).
+ batch = make([]*v1.Relationship, 0, r.batchSize)
+
+ // if we've sent the maximum number of batches per transaction, proceed to commit
+ if uint(len(batchesToBeCommitted))%r.batchesPerTransaction != 0 {
+ continue
+ }
+
+ if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil {
+ return fmt.Errorf("error committing batches: %w", err)
+ }
+
+ // commitStream closed the stream; start a new one for the next transaction.
+ relationshipWriter, err = r.client.BulkImportRelationships(ctx)
+ if err != nil {
+ return fmt.Errorf("error creating new writer stream: %w", err)
+ }
+
+ batchesToBeCommitted = batchesToBeCommitted[:0]
+ }
+ }
+
+ // Write the last batch
+ if len(batch) > 0 {
+ // Since we are going to close the stream anyway after the last batch, and given the actual error
+ // is only returned on CloseAndRecv(), we have to ignore the error here in order to get the actual
+ // underlying error that caused Send() to fail. It also gives us the opportunity to retry it
+ // in case it failed.
+ batchesToBeCommitted = append(batchesToBeCommitted, batch)
+ _ = relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{Relationships: batch})
+ }
+
+ if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil {
+ return fmt.Errorf("error committing last set of batches: %w", err)
+ }
+
+ r.bar.Describe("completed import")
+ if err := r.bar.Finish(); err != nil {
+ log.Warn().Err(err).Msg("error finalizing progress bar")
+ }
+
+ totalTime := time.Since(relationshipWriteStart)
+ log.Info().
+ Uint("batches", r.writtenBatches).
+ Uint("relationships_loaded", r.writtenRels).
+ Uint("relationships_skipped", r.skippedRels).
+ Uint("duplicate_relationships", r.duplicateRels).
+ Uint("relationships_filtered_out", r.filteredOutRels).
+ Uint("retried_errors", r.totalRetries).
+ Uint64("perSecond", perSec(uint64(r.writtenRels), totalTime)).
+ Stringer("duration", totalTime).
+ Msg("finished restore")
+ return nil
+}
+
+// commitStream finalizes the current bulk-import stream (which commits its
+// transaction), classifies any commit error (canceled / unknown / conflict /
+// retryable), and reacts according to the configured conflict strategy —
+// skipping, retrying with TOUCH semantics via writeBatchesWithRetry, or
+// failing outright. It also updates the restore stats and the progress bar.
+func (r *restorer) commitStream(ctx context.Context, bulkImportClient v1.ExperimentalService_BulkImportRelationshipsClient,
+ batchesToBeCommitted [][]*v1.Relationship,
+) error {
+ var numLoaded, expectedLoaded, retries uint
+ for _, b := range batchesToBeCommitted {
+ expectedLoaded += uint(len(b))
+ }
+
+ resp, err := bulkImportClient.CloseAndRecv() // transaction commit happens here
+
+ // Failure to commit transaction means the stream is closed, so it can't be reused any further
+ // The retry will be done using WriteRelationships instead of BulkImportRelationships
+ // This lets us retry with TOUCH semantics in case of failure due to duplicates
+ retryable := isRetryableError(err)
+ conflict := isAlreadyExistsError(err)
+ canceled, cancelErr := isCanceledError(ctx.Err(), err)
+ unknown := !retryable && !conflict && !canceled && err != nil
+
+ numBatches := uint(len(batchesToBeCommitted))
+
+ switch {
+ case canceled:
+ r.bar.Describe("backup restore aborted")
+ return cancelErr
+ case unknown:
+ r.bar.Describe("failed with unrecoverable error")
+ return fmt.Errorf("error finalizing write of %d batches: %w", len(batchesToBeCommitted), err)
+ case retryable && r.disableRetryErrors:
+ return err
+ case conflict && r.conflictStrategy == Skip:
+ r.skippedRels += expectedLoaded
+ r.skippedBatches += numBatches
+ r.duplicateBatches += numBatches
+ r.duplicateRels += expectedLoaded
+ r.bar.Describe("skipping conflicting batch")
+ case conflict && r.conflictStrategy == Touch:
+ // Re-write the whole set of batches with TOUCH semantics so the
+ // duplicates overwrite rather than conflict.
+ r.bar.Describe("touching conflicting batch")
+ r.duplicateRels += expectedLoaded
+ r.duplicateBatches += numBatches
+ r.totalRetries++
+ numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted)
+ if err != nil {
+ return fmt.Errorf("failed to write retried batch: %w", err)
+ }
+
+ retries++ // account for the initial attempt
+ r.writtenBatches += numBatches
+ r.writtenRels += numLoaded
+ case conflict && r.conflictStrategy == Fail:
+ r.bar.Describe("conflict detected, aborting restore")
+ return fmt.Errorf("duplicate relationships found")
+ case retryable:
+ r.bar.Describe("retrying after error")
+ r.totalRetries++
+ numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted)
+ if err != nil {
+ return fmt.Errorf("failed to write retried batch: %w", err)
+ }
+
+ retries++ // account for the initial attempt
+ r.writtenBatches += numBatches
+ r.writtenRels += numLoaded
+ default:
+ r.bar.Describe("restoring relationships from backup")
+ r.writtenBatches += numBatches
+ }
+
+ // it was a successful transaction commit without duplicates
+ if resp != nil {
+ numLoaded, err := safecast.ToUint(resp.NumLoaded)
+ if err != nil {
+ return spiceerrors.MustBugf("could not cast numLoaded to uint")
+ }
+ r.writtenRels += numLoaded
+ if uint64(expectedLoaded) != resp.NumLoaded {
+ log.Warn().Uint64("loaded", resp.NumLoaded).Uint("expected", expectedLoaded).Msg("unexpected number of relationships loaded")
+ }
+ }
+
+ writtenAndSkipped, err := safecast.ToInt64(r.writtenRels + r.skippedRels)
+ if err != nil {
+ return fmt.Errorf("too many written and skipped rels for an int64")
+ }
+
+ if err := r.bar.Set64(writtenAndSkipped); err != nil {
+ return fmt.Errorf("error incrementing progress bar: %w", err)
+ }
+
+ // When stderr is not a terminal the progress bar is not visible, so emit
+ // a trace log line with the running totals instead.
+ if !isatty.IsTerminal(os.Stderr.Fd()) {
+ log.Trace().
+ Uint("batches_written", r.writtenBatches).
+ Uint("relationships_written", r.writtenRels).
+ Uint("duplicate_batches", r.duplicateBatches).
+ Uint("duplicate_relationships", r.duplicateRels).
+ Uint("skipped_batches", r.skippedBatches).
+ Uint("skipped_relationships", r.skippedRels).
+ Uint("retries", retries).
+ Msg("restore progress")
+ }
+
+ return nil
+}
+
+// writeBatchesWithRetry writes a set of batches using touch semantics and without transactional guarantees -
+// each batch will be committed independently via WriteRelationships. If a batch fails with a retryable
+// error it will be retried up to defaultMaxRetries times with exponential backoff.
+// It returns the number of relationships loaded, the number of retries performed, and any terminal error.
+func (r *restorer) writeBatchesWithRetry(ctx context.Context, batches [][]*v1.Relationship) (uint, uint, error) {
+ backoffInterval := backoff.NewExponentialBackOff()
+ backoffInterval.InitialInterval = defaultBackoff
+ backoffInterval.MaxInterval = 2 * time.Second
+ backoffInterval.MaxElapsedTime = 0
+ backoffInterval.Reset()
+
+ var currentRetries, totalRetries, loadedRels uint
+ for _, batch := range batches {
+ // Convert the batch to TOUCH updates so duplicates overwrite cleanly.
+ updates := lo.Map[*v1.Relationship, *v1.RelationshipUpdate](batch, func(item *v1.Relationship, _ int) *v1.RelationshipUpdate {
+ return &v1.RelationshipUpdate{
+ Relationship: item,
+ Operation: v1.RelationshipUpdate_OPERATION_TOUCH,
+ }
+ })
+
+ for {
+ // Bound each write attempt by the configured request timeout.
+ cancelCtx, cancel := context.WithTimeout(ctx, r.requestTimeout)
+ _, err := r.client.WriteRelationships(cancelCtx, &v1.WriteRelationshipsRequest{Updates: updates})
+ cancel()
+
+ if isRetryableError(err) && currentRetries < defaultMaxRetries {
+ // throttle the writes so we don't overwhelm the server
+ bo := backoffInterval.NextBackOff()
+ r.bar.Describe(fmt.Sprintf("retrying write with backoff %s after error (attempt %d/%d)", bo,
+ currentRetries+1, defaultMaxRetries))
+ time.Sleep(bo)
+ currentRetries++
+ r.totalRetries++
+ totalRetries++
+ continue
+ }
+ if err != nil {
+ return 0, 0, err
+ }
+
+ // Success: reset the per-batch retry state for the next batch.
+ currentRetries = 0
+ backoffInterval.Reset()
+ loadedRels += uint(len(batch))
+ break
+ }
+ }
+
+ return loadedRels, totalRetries, nil
+}
+
+// isAlreadyExistsError reports whether err indicates a duplicate-relationship
+// conflict, either via the AlreadyExists gRPC code or one of the datastore
+// fallback strings in txConflictCodes.
+func isAlreadyExistsError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ if isGRPCCode(err, codes.AlreadyExists) {
+ return true
+ }
+
+ return isContainsErrorString(err, txConflictCodes...)
+}
+
+// isRetryableError reports whether err represents a transient failure worth
+// retrying: Unavailable/DeadlineExceeded gRPC codes, a datastore fallback
+// string from retryableErrorCodes, or a context deadline expiry.
+func isRetryableError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ if isGRPCCode(err, codes.Unavailable, codes.DeadlineExceeded) {
+ return true
+ }
+
+ if isContainsErrorString(err, retryableErrorCodes...) {
+ return true
+ }
+
+ return errors.Is(err, context.DeadlineExceeded)
+}
+
+// isCanceledError reports whether any of the given errors represents a
+// cancellation (context.Canceled or the Canceled gRPC code) and, if so,
+// returns the first such error; nil entries are ignored.
+func isCanceledError(errs ...error) (bool, error) {
+ for _, err := range errs {
+ if err == nil {
+ continue
+ }
+
+ if errors.Is(err, context.Canceled) {
+ return true, err
+ }
+
+ if isGRPCCode(err, codes.Canceled) {
+ return true, err
+ }
+ }
+
+ return false, nil
+}
+
+// isContainsErrorString reports whether err's message contains any of the
+// given substrings; a nil err never matches.
+func isContainsErrorString(err error, errStrings ...string) bool {
+ if err == nil {
+ return false
+ }
+
+ for _, errString := range errStrings {
+ if strings.Contains(err.Error(), errString) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// isGRPCCode reports whether err carries a gRPC status whose code matches any
+// of the given codes; non-status and nil errors return false.
+func isGRPCCode(err error, codes ...codes.Code) bool {
+ if err == nil {
+ return false
+ }
+
+ if s, ok := status.FromError(err); ok {
+ for _, code := range codes {
+ if s.Code() == code {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+// perSec returns the integer rate i/d in items per second; when d rounds down
+// to zero whole seconds it returns i itself to avoid division by zero.
+func perSec(i uint64, d time.Duration) uint64 {
+ secs := uint64(d.Seconds())
+ if secs == 0 {
+ return i
+ }
+ return i / secs
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/schema.go b/vendor/github.com/authzed/zed/internal/cmd/schema.go
new file mode 100644
index 0000000..bbe52f3
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/schema.go
@@ -0,0 +1,319 @@
+package cmd
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "strings"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/jzelinskie/stringz"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+ "golang.org/x/term"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/diff"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/schemadsl/generator"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/console"
+)
+
+// registerAdditionalSchemaCmds attaches the copy, write, and diff subcommands
+// (and their flags) to the given schema command.
+func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) {
+ schemaCmd.AddCommand(schemaCopyCmd)
+ schemaCopyCmd.Flags().Bool("json", false, "output as JSON")
+ schemaCopyCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before writing")
+
+ schemaCmd.AddCommand(schemaWriteCmd)
+ schemaWriteCmd.Flags().Bool("json", false, "output as JSON")
+ schemaWriteCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before writing")
+
+ schemaCmd.AddCommand(schemaDiffCmd)
+}
+
+// schemaWriteCmd writes a schema (from a .zed file argument or stdin) to the
+// currently configured permissions system.
+var schemaWriteCmd = &cobra.Command{
+ Use: "write <file?>",
+ Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
+ Short: "Write a schema file (.zed or stdin) to the current permissions system",
+ ValidArgsFunction: commands.FileExtensionCompletions("zed"),
+ RunE: schemaWriteCmdFunc,
+}
+
+// schemaCopyCmd reads the schema from one configured context and writes it to
+// another, optionally rewriting definition prefixes.
+var schemaCopyCmd = &cobra.Command{
+ Use: "copy <src context> <dest context>",
+ Short: "Copy a schema from one context into another",
+ Args: commands.ValidationWrapper(cobra.ExactArgs(2)),
+ ValidArgsFunction: ContextGet,
+ RunE: schemaCopyCmdFunc,
+}
+
+// schemaDiffCmd compiles two schema files and prints their differences.
+var schemaDiffCmd = &cobra.Command{
+ Use: "diff <before file> <after file>",
+ Short: "Diff two schema files",
+ Args: commands.ValidationWrapper(cobra.ExactArgs(2)),
+ RunE: schemaDiffCmdFunc,
+}
+
+// schemaDiffCmdFunc reads and compiles the two schema files given as
+// arguments, diffs them, and prints the added/removed/changed definitions and
+// caveats to the console.
+func schemaDiffCmdFunc(_ *cobra.Command, args []string) error {
+ beforeBytes, err := os.ReadFile(args[0])
+ if err != nil {
+ return fmt.Errorf("failed to read before schema file: %w", err)
+ }
+
+ afterBytes, err := os.ReadFile(args[1])
+ if err != nil {
+ return fmt.Errorf("failed to read after schema file: %w", err)
+ }
+
+ before, err := compiler.Compile(
+ compiler.InputSchema{Source: input.Source(args[0]), SchemaString: string(beforeBytes)},
+ compiler.AllowUnprefixedObjectType(),
+ )
+ if err != nil {
+ return err
+ }
+
+ after, err := compiler.Compile(
+ compiler.InputSchema{Source: input.Source(args[1]), SchemaString: string(afterBytes)},
+ compiler.AllowUnprefixedObjectType(),
+ )
+ if err != nil {
+ return err
+ }
+
+ // Wrap both compiled schemas in the diffable representation and compute
+ // the structural diff between them.
+ dbefore := diff.NewDiffableSchemaFromCompiledSchema(before)
+ dafter := diff.NewDiffableSchemaFromCompiledSchema(after)
+
+ schemaDiff, err := diff.DiffSchemas(dbefore, dafter, types.Default.TypeSet)
+ if err != nil {
+ return err
+ }
+
+ for _, ns := range schemaDiff.AddedNamespaces {
+ console.Printf("Added definition: %s\n", ns)
+ }
+
+ for _, ns := range schemaDiff.RemovedNamespaces {
+ console.Printf("Removed definition: %s\n", ns)
+ }
+
+ for nsName, ns := range schemaDiff.ChangedNamespaces {
+ console.Printf("Changed definition: %s\n", nsName)
+ for _, delta := range ns.Deltas() {
+ console.Printf("\t %s: %s\n", delta.Type, delta.RelationName)
+ }
+ }
+
+ for _, caveat := range schemaDiff.AddedCaveats {
+ console.Printf("Added caveat: %s\n", caveat)
+ }
+
+ for _, caveat := range schemaDiff.RemovedCaveats {
+ console.Printf("Removed caveat: %s\n", caveat)
+ }
+
+ return nil
+}
+
+// schemaCopyCmdFunc reads the schema from the source context (args[0]),
+// optionally rewrites its definition prefix, and writes the result to the
+// destination context (args[1]). Note that read/write failures are handled
+// with log.Fatal, which exits the process rather than returning an error.
+func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error {
+ _, secretStore := client.DefaultStorage()
+ srcClient, err := client.NewClientForContext(cmd, args[0], secretStore)
+ if err != nil {
+ return err
+ }
+
+ destClient, err := client.NewClientForContext(cmd, args[1], secretStore)
+ if err != nil {
+ return err
+ }
+
+ readRequest := &v1.ReadSchemaRequest{}
+ log.Trace().Interface("request", readRequest).Msg("requesting schema read")
+
+ readResp, err := srcClient.ReadSchema(cmd.Context(), readRequest)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to read schema")
+ }
+ log.Trace().Interface("response", readResp).Msg("read schema")
+
+ // The prefix is resolved from the flag or, failing that, parsed out of
+ // the schema text that was just read.
+ prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), nil, &readResp.SchemaText)
+ if err != nil {
+ return err
+ }
+
+ schemaText, err := rewriteSchema(readResp.SchemaText, prefix)
+ if err != nil {
+ return err
+ }
+
+ writeRequest := &v1.WriteSchemaRequest{Schema: schemaText}
+ log.Trace().Interface("request", writeRequest).Msg("writing schema")
+
+ resp, err := destClient.WriteSchema(cmd.Context(), writeRequest)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to write schema")
+ }
+ log.Trace().Interface("response", resp).Msg("wrote schema")
+
+ if cobrautil.MustGetBool(cmd, "json") {
+ prettyProto, err := commands.PrettyProto(resp)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to convert schema to JSON")
+ }
+
+ console.Println(string(prettyProto))
+ return nil
+ }
+
+ return nil
+}
+
+// schemaWriteCmdFunc writes a schema to the current permissions system. The
+// schema is read from the file given as args[0], or from stdin when no
+// argument is provided; an optional definition prefix is applied before the
+// write. Write failures are handled with log.Fatal, which exits the process.
+func schemaWriteCmdFunc(cmd *cobra.Command, args []string) error {
+ // NOTE(review): this guards against an interactive session with no input,
+ // but it checks os.Stdout's file descriptor — confirm whether os.Stdin
+ // was intended for detecting piped input.
+ intFd, err := safecast.ToInt(uint(os.Stdout.Fd()))
+ if err != nil {
+ return err
+ }
+ if len(args) == 0 && term.IsTerminal(intFd) {
+ return fmt.Errorf("must provide file path or contents via stdin")
+ }
+
+ client, err := client.NewClient(cmd)
+ if err != nil {
+ return err
+ }
+ var schemaBytes []byte
+ switch len(args) {
+ case 1:
+ schemaBytes, err = os.ReadFile(args[0])
+ if err != nil {
+ return fmt.Errorf("failed to read schema file: %w", err)
+ }
+ log.Trace().Str("schema", string(schemaBytes)).Str("file", args[0]).Msg("read schema from file")
+ case 0:
+ schemaBytes, err = io.ReadAll(os.Stdin)
+ if err != nil {
+ return fmt.Errorf("failed to read schema file: %w", err)
+ }
+ log.Trace().Str("schema", string(schemaBytes)).Msg("read schema from stdin")
+ default:
+ // Unreachable: cobra.MaximumNArgs(1) restricts args to 0 or 1.
+ panic("schemaWriteCmdFunc called with incorrect number of arguments")
+ }
+
+ if len(schemaBytes) == 0 {
+ return errors.New("attempted to write empty schema")
+ }
+
+ prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), client, nil)
+ if err != nil {
+ return err
+ }
+
+ schemaText, err := rewriteSchema(string(schemaBytes), prefix)
+ if err != nil {
+ return err
+ }
+
+ request := &v1.WriteSchemaRequest{Schema: schemaText}
+ log.Trace().Interface("request", request).Msg("writing schema")
+
+ resp, err := client.WriteSchema(cmd.Context(), request)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to write schema")
+ }
+ log.Trace().Interface("response", resp).Msg("wrote schema")
+
+ if cobrautil.MustGetBool(cmd, "json") {
+ prettyProto, err := commands.PrettyProto(resp)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to convert schema to JSON")
+ }
+
+ console.Println(string(prettyProto))
+ return nil
+ }
+
+ return nil
+}
+
+// rewriteSchema rewrites the given existing schema to include the specified prefix on all definitions.
+// When definitionPrefix is empty the schema text is returned unchanged; otherwise the schema is
+// recompiled with the prefix applied (validation skipped) and regenerated from the result.
+func rewriteSchema(existingSchemaText string, definitionPrefix string) (string, error) {
+ if definitionPrefix == "" {
+ return existingSchemaText, nil
+ }
+
+ compiled, err := compiler.Compile(
+ compiler.InputSchema{Source: input.Source("schema"), SchemaString: existingSchemaText},
+ compiler.ObjectTypePrefix(definitionPrefix),
+ compiler.SkipValidation(),
+ )
+ if err != nil {
+ return "", err
+ }
+
+ generated, _, err := generator.GenerateSchema(compiled.OrderedDefinitions)
+ return generated, err
+}
+
+// determinePrefixForSchema determines the prefix to be applied to a schema that will be written.
+//
+// If specifiedPrefix is non-empty, it is returned immediately.
+// If existingSchema is non-nil, it is parsed for the prefix.
+// Otherwise, the client is used to retrieve the existing schema (if any), and the prefix is retrieved from there.
+// A single shared prefix across all definitions is returned; otherwise the empty string is returned.
+func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, client client.Client, existingSchema *string) (string, error) {
+ if specifiedPrefix != "" {
+ return specifiedPrefix, nil
+ }
+
+ var schemaText string
+ if existingSchema != nil {
+ schemaText = *existingSchema
+ } else {
+ readSchemaText, err := commands.ReadSchema(ctx, client)
+ if err != nil {
+ // NOTE(review): the read error is discarded here and an empty prefix
+ // is returned — confirm whether this should be `return "", err`.
+ return "", nil
+ }
+ schemaText = readSchemaText
+ }
+
+ // If there is no schema found, return the empty string.
+ if schemaText == "" {
+ return "", nil
+ }
+
+ // Otherwise, compile the schema and grab the prefixes of the namespaces defined.
+ found, err := compiler.Compile(
+ compiler.InputSchema{Source: input.Source("schema"), SchemaString: schemaText},
+ compiler.AllowUnprefixedObjectType(),
+ compiler.SkipValidation(),
+ )
+ if err != nil {
+ return "", err
+ }
+
+ foundPrefixes := make([]string, 0, len(found.OrderedDefinitions))
+ for _, def := range found.OrderedDefinitions {
+ if strings.Contains(def.GetName(), "/") {
+ parts := strings.Split(def.GetName(), "/")
+ foundPrefixes = append(foundPrefixes, parts[0])
+ } else {
+ foundPrefixes = append(foundPrefixes, "")
+ }
+ }
+
+ // Only a single, unambiguous prefix is usable; anything else yields "".
+ prefixes := stringz.Dedup(foundPrefixes)
+ if len(prefixes) == 1 {
+ prefix := prefixes[0]
+ log.Debug().Str("prefix", prefix).Msg("found schema definition prefix")
+ return prefix, nil
+ }
+
+ return "", nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/validate.go b/vendor/github.com/authzed/zed/internal/cmd/validate.go
new file mode 100644
index 0000000..ffca1c6
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/validate.go
@@ -0,0 +1,414 @@
+package cmd
+
+import (
+ "errors"
+ "fmt"
+ "net/url"
+ "os"
+ "strings"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/charmbracelet/lipgloss"
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/muesli/termenv"
+ "github.com/spf13/cobra"
+
+ composable "github.com/authzed/spicedb/pkg/composableschemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/development"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/validationfile"
+
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/console"
+ "github.com/authzed/zed/internal/decode"
+ "github.com/authzed/zed/internal/printers"
+)
+
+var (
+ // NOTE: these need to be set *after* the renderer has been set, otherwise
+ // the forceColor setting can't work, hence the thunking: each style is a
+ // zero-argument function evaluated at render time rather than at init.
+ success = func() string {
+ return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("10")).Render("Success!")
+ }
+ complete = func() string { return lipgloss.NewStyle().Bold(true).Render("complete") }
+ errorPrefix = func() string { return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("9")).Render("error: ") }
+ warningPrefix = func() string {
+ return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("3")).Render("warning: ")
+ }
+ errorMessageStyle = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true).Width(80) }
+ linePrefixStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("12")) }
+ highlightedSourceStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) }
+ highlightedLineStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) }
+ codeStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("8")) }
+ highlightedCodeStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("15")) }
+ traceStyle = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true) }
+)
+
+// registerValidateCmd attaches the validate subcommand and its flags to the
+// given parent command.
+func registerValidateCmd(cmd *cobra.Command) {
+ validateCmd.Flags().Bool("force-color", false, "force color code output even in non-tty environments")
+ validateCmd.Flags().String("schema-type", "", "force validation according to specific schema syntax (\"\", \"composable\", \"standard\")")
+ cmd.AddCommand(validateCmd)
+}
+
+// validateCmd validates schema and validation files from local paths or
+// several supported URL sources; a failed validation prints the result and
+// exits the process with status 1.
+var validateCmd = &cobra.Command{
+ Use: "validate <validation_file_or_schema_file>",
+ Short: "Validates the given validation file (.yaml, .zaml) or schema file (.zed)",
+ Example: `
+ From a local file (with prefix):
+ zed validate file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
+
+ From a local file (no prefix):
+ zed validate authzed-x7izWU8_2Gw3.yaml
+
+ From a gist:
+ zed validate https://gist.github.com/ecordell/8e3b613a677e3c844742cf24421c08b6
+
+ From a playground link:
+ zed validate https://play.authzed.com/s/iksdFvCtvnkR/schema
+
+ From pastebin:
+ zed validate https://pastebin.com/8qU45rVK
+
+ From a devtools instance:
+ zed validate https://localhost:8443/download`,
+ Args: commands.ValidationWrapper(cobra.MinimumNArgs(1)),
+ ValidArgsFunction: commands.FileExtensionCompletions("zed", "yaml", "zaml"),
+ PreRunE: validatePreRunE,
+ RunE: func(cmd *cobra.Command, filenames []string) error {
+ result, shouldExit, err := validateCmdFunc(cmd, filenames)
+ if err != nil {
+ return err
+ }
+ console.Print(result)
+ if shouldExit {
+ os.Exit(1)
+ }
+ return nil
+ },
+
+ // A schema that causes the parser/compiler to error will halt execution
+ // of this command with an error. In that case, we want to just display the error,
+ // rather than showing usage for this command.
+ SilenceUsage: true,
+}
+
+// validSchemaTypes lists the accepted values for the --schema-type flag.
+var validSchemaTypes = []string{"", "standard", "composable"}
+
+// validatePreRunE applies the --force-color setting and rejects any
+// --schema-type value not listed in validSchemaTypes.
+func validatePreRunE(cmd *cobra.Command, _ []string) error {
+ // Override lipgloss's autodetection of whether it's in a terminal environment
+ // and display things in color anyway. This can be nice in CI environments that
+ // support it.
+ setForceColor := cobrautil.MustGetBool(cmd, "force-color")
+ if setForceColor {
+ lipgloss.SetColorProfile(termenv.ANSI256)
+ }
+
+ schemaType := cobrautil.MustGetString(cmd, "schema-type")
+ schemaTypeValid := false
+ for _, validType := range validSchemaTypes {
+ if schemaType == validType {
+ schemaTypeValid = true
+ }
+ }
+ if !schemaTypeValid {
+ return fmt.Errorf("schema-type must be one of \"\", \"standard\", \"composable\". received: %s", schemaType)
+ }
+
+ return nil
+}
+
+// validateCmdFunc returns the string to print to the user, whether to return a non-zero status code, and any errors.
+// For each file it: decodes the validation/schema file, reports compile errors
+// according to the selected schema dialect, builds a development context from
+// the schema and relationships, then runs assertions, expected relations, and
+// warning checks, appending all human-readable output to a single builder.
+func validateCmdFunc(cmd *cobra.Command, filenames []string) (string, bool, error) {
+ // Initialize variables for multiple files
+ var (
+ totalFiles = len(filenames)
+ successfullyValidatedFiles = 0
+ shouldExit = false
+ toPrint = &strings.Builder{}
+ schemaType = cobrautil.MustGetString(cmd, "schema-type")
+ )
+
+ for _, filename := range filenames {
+ // If we're running over multiple files, print the filename for context/debugging purposes
+ if totalFiles > 1 {
+ toPrint.WriteString(filename + "\n")
+ }
+
+ u, err := url.Parse(filename)
+ if err != nil {
+ return "", false, err
+ }
+
+ decoder, err := decode.DecoderForURL(u)
+ if err != nil {
+ return "", false, err
+ }
+
+ // Decode the file; classifyErrors splits any decode/compile error by
+ // schema dialect (standard vs composable) plus everything else.
+ var parsed validationfile.ValidationFile
+ validateContents, isOnlySchema, err := decoder(&parsed)
+ standardErrors, composableErrs, otherErrs := classifyErrors(err)
+
+ switch schemaType {
+ case "standard":
+ if standardErrors != nil {
+ var errWithSource spiceerrors.WithSourceError
+ if errors.As(standardErrors, &errWithSource) {
+ outputErrorWithSource(toPrint, validateContents, errWithSource)
+ shouldExit = true
+ }
+ return "", shouldExit, standardErrors
+ }
+ case "composable":
+ if composableErrs != nil {
+ var errWithSource spiceerrors.WithSourceError
+ if errors.As(composableErrs, &errWithSource) {
+ outputErrorWithSource(toPrint, validateContents, errWithSource)
+ shouldExit = true
+ }
+ return "", shouldExit, composableErrs
+ }
+ default:
+ // By default, validate will attempt to validate a schema first according to composable schema rules,
+ // then standard schema rules,
+ // and if both fail it will show the errors from composable schema.
+ if composableErrs != nil && standardErrors != nil {
+ var errWithSource spiceerrors.WithSourceError
+ if errors.As(composableErrs, &errWithSource) {
+ outputErrorWithSource(toPrint, validateContents, errWithSource)
+ shouldExit = true
+ }
+ return "", shouldExit, composableErrs
+ }
+ }
+
+ if otherErrs != nil {
+ return "", false, otherErrs
+ }
+
+ tuples := make([]*core.RelationTuple, 0)
+ totalAssertions := 0
+ totalRelationsValidated := 0
+
+ for _, rel := range parsed.Relationships.Relationships {
+ tuples = append(tuples, rel.ToCoreTuple())
+ }
+
+ // Create the development context for each run
+ ctx := cmd.Context()
+ devCtx, devErrs, err := development.NewDevContext(ctx, &devinterface.RequestContext{
+ Schema: parsed.Schema.Schema,
+ Relationships: tuples,
+ })
+ if err != nil {
+ return "", false, err
+ }
+ if devErrs != nil {
+ // Developer errors carry line numbers relative to the schema block,
+ // so offset them by where the schema starts in the file (unless the
+ // file is a bare schema).
+ schemaOffset := parsed.Schema.SourcePosition.LineNumber
+ if isOnlySchema {
+ schemaOffset = 0
+ }
+
+ // Output errors
+ outputDeveloperErrorsWithLineOffset(toPrint, validateContents, devErrs.InputErrors, schemaOffset)
+ return toPrint.String(), true, nil
+ }
+ // Run assertions
+ adevErrs, aerr := development.RunAllAssertions(devCtx, &parsed.Assertions)
+ if aerr != nil {
+ return "", false, aerr
+ }
+ if adevErrs != nil {
+ outputDeveloperErrors(toPrint, validateContents, adevErrs)
+ return toPrint.String(), true, nil
+ }
+ successfullyValidatedFiles++
+
+ // Run expected relations for file
+ _, erDevErrs, rerr := development.RunValidation(devCtx, &parsed.ExpectedRelations)
+ if rerr != nil {
+ return "", false, rerr
+ }
+ if erDevErrs != nil {
+ outputDeveloperErrors(toPrint, validateContents, erDevErrs)
+ return toPrint.String(), true, nil
+ }
+ // Print out any warnings for file
+ warnings, err := development.GetWarnings(ctx, devCtx)
+ if err != nil {
+ return "", false, err
+ }
+
+ if len(warnings) > 0 {
+ for _, warning := range warnings {
+ fmt.Fprintf(toPrint, "%s%s\n", warningPrefix(), warning.Message)
+ outputForLine(toPrint, validateContents, uint64(warning.Line), warning.SourceCode, uint64(warning.Column)) // warning.LineNumber is 1-indexed
+ toPrint.WriteString("\n")
+ }
+
+ toPrint.WriteString(complete())
+ } else {
+ toPrint.WriteString(success())
+ }
+ totalAssertions += len(parsed.Assertions.AssertTrue) + len(parsed.Assertions.AssertFalse)
+ totalRelationsValidated += len(parsed.ExpectedRelations.ValidationMap)
+
+ fmt.Fprintf(toPrint, " - %d relationships loaded, %d assertions run, %d expected relations validated\n",
+ len(tuples),
+ totalAssertions,
+ totalRelationsValidated)
+ }
+
+ if totalFiles > 1 {
+ fmt.Fprintf(toPrint, "total files: %d, successfully validated files: %d\n", totalFiles, successfullyValidatedFiles)
+ }
+
+ return toPrint.String(), shouldExit, nil
+}
+
+// outputErrorWithSource writes a single error that carries a source position
+// to sb, followed by a window of surrounding lines from validateContents.
+func outputErrorWithSource(sb *strings.Builder, validateContents []byte, errWithSource spiceerrors.WithSourceError) {
+	fmt.Fprintf(sb, "%s%s\n", errorPrefix(), errorMessageStyle().Render(errWithSource.Error()))
+	outputForLine(sb, validateContents, errWithSource.LineNumber, errWithSource.SourceCodeString, 0) // errWithSource.LineNumber is 1-indexed
+}
+
+// outputForLine prints up to three lines of context before and after the
+// 1-indexed error line; the error line itself is rendered with the given
+// highlight string and caret column, the surrounding lines without.
+func outputForLine(sb *strings.Builder, validateContents []byte, oneIndexedLineNumber uint64, sourceCodeString string, oneIndexedColumnPosition uint64) {
+	lines := strings.Split(string(validateContents), "\n")
+	// These should be fine to be zero if the cast fails.
+	intLineNumber, _ := safecast.ToInt(oneIndexedLineNumber)
+	intColumnPosition, _ := safecast.ToInt(oneIndexedColumnPosition)
+	errorLineNumber := intLineNumber - 1
+	for i := errorLineNumber - 3; i < errorLineNumber+3; i++ {
+		if i == errorLineNumber {
+			// Only the error line receives the highlight and caret position.
+			renderLine(sb, lines, i, sourceCodeString, errorLineNumber, intColumnPosition-1)
+		} else {
+			renderLine(sb, lines, i, "", errorLineNumber, -1)
+		}
+	}
+}
+
+// outputDeveloperErrors writes each developer error with surrounding source
+// context, applying no line offset.
+func outputDeveloperErrors(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError) {
+	outputDeveloperErrorsWithLineOffset(sb, validateContents, devErrors, 0)
+}
+
+// outputDeveloperErrorsWithLineOffset writes each developer error with
+// surrounding source context, shifting each error's reported line number by
+// lineOffset (used when the schema is embedded within a larger file).
+func outputDeveloperErrorsWithLineOffset(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError, lineOffset int) {
+	lines := strings.Split(string(validateContents), "\n")
+
+	for _, devErr := range devErrors {
+		outputDeveloperError(sb, devErr, lines, lineOffset)
+	}
+}
+
+// outputDeveloperError writes one developer error: its message, a window of
+// source lines around the (offset) error line, and — when the error carries
+// resolved check debug information — a rendered check trace as explanation.
+func outputDeveloperError(sb *strings.Builder, devError *devinterface.DeveloperError, lines []string, lineOffset int) {
+	fmt.Fprintf(sb, "%s %s\n", errorPrefix(), errorMessageStyle().Render(devError.Message))
+	errorLineNumber := int(devError.Line) - 1 + lineOffset // devError.Line is 1-indexed
+	// Show up to three lines of context on either side of the error line.
+	for i := errorLineNumber - 3; i < errorLineNumber+3; i++ {
+		if i == errorLineNumber {
+			renderLine(sb, lines, i, devError.Context, errorLineNumber, -1)
+		} else {
+			renderLine(sb, lines, i, "", errorLineNumber, -1)
+		}
+	}
+
+	if devError.CheckResolvedDebugInformation != nil && devError.CheckResolvedDebugInformation.Check != nil {
+		fmt.Fprintf(sb, "\n %s\n", traceStyle().Render("Explanation:"))
+		tp := printers.NewTreePrinter()
+		printers.DisplayCheckTrace(devError.CheckResolvedDebugInformation.Check, tp, true)
+		sb.WriteString(tp.Indented())
+	}
+
+	sb.WriteString("\n\n")
+}
+
+// renderLine writes a single numbered source line to sb. The line at
+// highlightLineIndex is styled as the error line (with a ">" gutter); when a
+// highlight substring or caret column is resolved, an extra gutter line with
+// a ^~~~ marker is emitted under the highlighted span. Out-of-range indices
+// are silently skipped so context windows can run past file boundaries.
+func renderLine(sb *strings.Builder, lines []string, index int, highlight string, highlightLineIndex int, highlightStartingColumnIndex int) {
+	if index < 0 || index >= len(lines) {
+		return
+	}
+
+	lineNumberLength := len(fmt.Sprintf("%d", len(lines)))
+	lineContents := strings.ReplaceAll(lines[index], "\t", " ")
+	lineDelimiter := "|"
+
+	highlightLength := max(0, len(highlight)-1)
+	highlightColumnIndex := -1
+
+	// If the highlight string was provided, then we need to find the index of the highlight
+	// string in the line contents to determine where to place the caret.
+	if len(highlight) > 0 {
+		offset := 0
+		for {
+			// NOTE(review): strings.Index returns -1 when absent; in that case
+			// foundIndex (= offset-1) is compared before the break below and can
+			// be accepted as the highlight column — confirm upstream intent.
+			foundRelativeIndex := strings.Index(lineContents[offset:], highlight)
+			foundIndex := foundRelativeIndex + offset
+			if foundIndex >= highlightStartingColumnIndex {
+				highlightColumnIndex = foundIndex
+				break
+			}
+
+			offset = foundIndex + 1
+			if foundRelativeIndex < 0 || foundIndex > len(lineContents) {
+				break
+			}
+		}
+	} else if highlightStartingColumnIndex >= 0 {
+		// Otherwise, just show a caret at the specified starting column, if any.
+		highlightColumnIndex = highlightStartingColumnIndex
+	}
+
+	lineNumberStr := fmt.Sprintf("%d", index+1)
+	noNumberSpaces := strings.Repeat(" ", lineNumberLength)
+
+	lineNumberStyle := linePrefixStyle
+	lineContentsStyle := codeStyle
+	if index == highlightLineIndex {
+		lineNumberStyle = highlightedLineStyle
+		lineContentsStyle = highlightedCodeStyle
+		lineDelimiter = ">"
+	}
+
+	lineNumberSpacer := strings.Repeat(" ", lineNumberLength-len(lineNumberStr))
+
+	if highlightColumnIndex < 0 {
+		fmt.Fprintf(sb, " %s%s %s %s\n", lineNumberSpacer, lineNumberStyle().Render(lineNumberStr), lineDelimiter, lineContentsStyle().Render(lineContents))
+	} else {
+		// Split the line around the highlighted span so the span can carry its
+		// own style, then emit the caret/tilde marker line beneath it.
+		fmt.Fprintf(sb, " %s%s %s %s%s%s\n",
+			lineNumberSpacer,
+			lineNumberStyle().Render(lineNumberStr),
+			lineDelimiter,
+			lineContentsStyle().Render(lineContents[0:highlightColumnIndex]),
+			highlightedSourceStyle().Render(highlight),
+			lineContentsStyle().Render(lineContents[highlightColumnIndex+len(highlight):]))
+
+		fmt.Fprintf(sb, " %s %s %s%s%s\n",
+			noNumberSpaces,
+			lineDelimiter,
+			strings.Repeat(" ", highlightColumnIndex),
+			highlightedSourceStyle().Render("^"),
+			highlightedSourceStyle().Render(strings.Repeat("~", highlightLength)))
+	}
+}
+
+// classifyErrors splits err into a standard-DSL compiler error, a
+// composable-DSL compiler error, and any other (non-compiler) error, in that
+// return order. Each return value is nil when no error of that kind is
+// present; the third is set only when neither compiler error matched.
+func classifyErrors(err error) (error, error, error) {
+	if err == nil {
+		return nil, nil, nil
+	}
+	var standardErr compiler.BaseCompilerError
+	var composableErr composable.BaseCompilerError
+	var retStandard, retComposable, allOthers error
+
+	ok := errors.As(err, &standardErr)
+	if ok {
+		retStandard = standardErr
+	}
+	ok = errors.As(err, &composableErr)
+	if ok {
+		retComposable = composableErr
+	}
+
+	if retStandard == nil && retComposable == nil {
+		allOthers = err
+	}
+
+	return retStandard, retComposable, allOthers
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/version.go b/vendor/github.com/authzed/zed/internal/cmd/version.go
new file mode 100644
index 0000000..b87ec66
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/version.go
@@ -0,0 +1,64 @@
+package cmd
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/gookit/color"
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/mattn/go-isatty"
+ "github.com/spf13/cobra"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+
+ "github.com/authzed/authzed-go/pkg/responsemeta"
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/console"
+)
+
+// versionCmdFunc implements "zed version": it prints the client version and,
+// when --include-remote-version is set and a current context/token can be
+// resolved, also queries the connected service and prints the server version
+// reported via the ReadSchema response headers.
+func versionCmdFunc(cmd *cobra.Command, _ []string) error {
+	// Disable colored output when stdout is not a terminal (e.g. piped).
+	if !isatty.IsTerminal(os.Stdout.Fd()) {
+		color.Disable()
+	}
+
+	includeRemoteVersion := cobrautil.MustGetBool(cmd, "include-remote-version")
+	hasContext := false
+	if includeRemoteVersion {
+		configStore, secretStore := client.DefaultStorage()
+		_, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
+		hasContext = err == nil
+	}
+
+	// Only label the client line when a service line will follow.
+	if hasContext && includeRemoteVersion {
+		green := color.FgGreen.Render
+		fmt.Print(green("client: "))
+	}
+
+	console.Println(cobrautil.UsageVersion("zed", cobrautil.MustGetBool(cmd, "include-deps")))
+
+	if hasContext && includeRemoteVersion {
+		client, err := client.NewClient(cmd)
+		if err != nil {
+			return err
+		}
+
+		// NOTE: we ignore the error here, as it may be due to a schema not existing, or
+		// the client being unable to connect, etc. We just treat all such cases as an unknown
+		// version.
+		var headerMD metadata.MD
+		_, _ = client.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}, grpc.Header(&headerMD))
+		version := headerMD.Get(string(responsemeta.ServerVersion))
+
+		blue := color.FgLightBlue.Render
+		fmt.Print(blue("service: "))
+		if len(version) == 1 {
+			console.Println(version[0])
+		} else {
+			console.Println("(unknown)")
+		}
+	}
+
+	return nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/commands/completion.go b/vendor/github.com/authzed/zed/internal/commands/completion.go
new file mode 100644
index 0000000..dd24a74
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/commands/completion.go
@@ -0,0 +1,165 @@
+package commands
+
+import (
+ "errors"
+ "strings"
+
+ "github.com/spf13/cobra"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+
+ "github.com/authzed/zed/internal/client"
+)
+
+type CompletionArgumentType int
+
+const (
+ ResourceType CompletionArgumentType = iota
+ ResourceID
+ Permission
+ SubjectType
+ SubjectID
+ SubjectTypeWithOptionalRelation
+)
+
+// FileExtensionCompletions returns a cobra completion function that limits
+// shell file completion to the given file extensions.
+func FileExtensionCompletions(extension ...string) func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
+	return func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
+		return extension, cobra.ShellCompDirectiveFilterFileExt
+	}
+}
+
+// GetArgs returns a cobra completion function for a command whose positional
+// arguments follow the given sequence of field types. Completions are driven
+// by the schema read from the current SpiceDB context: object-type fields
+// complete to definition names, Permission completes to the relations of the
+// previously-entered resource type, and subject-type-with-relation fields
+// additionally complete "type#relation" forms.
+func GetArgs(fields ...CompletionArgumentType) func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
+	return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
+		// Read the current schema, if any.
+		schema, err := readSchema(cmd)
+		if err != nil {
+			return nil, cobra.ShellCompDirectiveError
+		}
+
+		// Find the specified resource type, if any.
+		var resourceType string
+	loop:
+		for index, arg := range args {
+			field := fields[index]
+			switch field {
+			case ResourceType:
+				resourceType = arg
+				break loop
+
+			case ResourceID:
+				// A ResourceID argument is "type:id"; the type portion is what
+				// we need for later Permission completion.
+				pieces := strings.Split(arg, ":")
+				if len(pieces) >= 1 {
+					resourceType = pieces[0]
+					break loop
+				}
+			}
+		}
+
+		// Handle : on resource and subject IDs.
+		if strings.HasSuffix(toComplete, ":") && (fields[len(args)] == ResourceID || fields[len(args)] == SubjectID) {
+			comps := []string{}
+			comps = cobra.AppendActiveHelp(comps, "Please enter an object ID")
+			return comps, cobra.ShellCompDirectiveNoFileComp
+		}
+
+		// Handle # on subject types. If the toComplete contains a valid subject,
+		// then we should return the relation names. Note that we cannot do this
+		// on the # character because shell autocompletion won't send it to us.
+		if len(args) == len(fields)-1 && toComplete != "" && fields[len(args)] == SubjectTypeWithOptionalRelation {
+			for _, objDef := range schema.ObjectDefinitions {
+				subjectType := toComplete
+				if objDef.Name == subjectType {
+					relationNames := make([]string, 0)
+					relationNames = append(relationNames, subjectType)
+					for _, relation := range objDef.Relation {
+						relationNames = append(relationNames, subjectType+"#"+relation.Name)
+					}
+					return relationNames, cobra.ShellCompDirectiveNoFileComp
+				}
+			}
+		}
+
+		if len(args) >= len(fields) {
+			// If we have all the arguments, return no completions.
+			return nil, cobra.ShellCompDirectiveNoFileComp
+		}
+
+		// Return the completions.
+		currentFieldType := fields[len(args)]
+		switch currentFieldType {
+		case ResourceType:
+			fallthrough
+
+		case SubjectType:
+			fallthrough
+
+		case SubjectID:
+			fallthrough
+
+		case SubjectTypeWithOptionalRelation:
+			fallthrough
+
+		case ResourceID:
+			// All of the above complete to object definition names.
+			resourceTypeNames := make([]string, 0, len(schema.ObjectDefinitions))
+			for _, objDef := range schema.ObjectDefinitions {
+				resourceTypeNames = append(resourceTypeNames, objDef.Name)
+			}
+
+			flags := cobra.ShellCompDirectiveNoFileComp
+			if currentFieldType == ResourceID || currentFieldType == SubjectID || currentFieldType == SubjectTypeWithOptionalRelation {
+				// Suppress the trailing space so the user can continue typing
+				// the ":id" or "#relation" suffix.
+				flags |= cobra.ShellCompDirectiveNoSpace
+			}
+
+			return resourceTypeNames, flags
+
+		case Permission:
+			if resourceType == "" {
+				return nil, cobra.ShellCompDirectiveNoFileComp
+			}
+
+			relationNames := make([]string, 0)
+			for _, objDef := range schema.ObjectDefinitions {
+				if objDef.Name == resourceType {
+					for _, relation := range objDef.Relation {
+						relationNames = append(relationNames, relation.Name)
+					}
+				}
+			}
+			return relationNames, cobra.ShellCompDirectiveNoFileComp
+		}
+
+		return nil, cobra.ShellCompDirectiveDefault
+	}
+}
+
+// readSchema reads the schema from the current SpiceDB context and compiles
+// it (validation skipped, unprefixed object types allowed) for use by shell
+// completion. Returns an error when no schema has been defined.
+func readSchema(cmd *cobra.Command) (*compiler.CompiledSchema, error) {
+	// TODO: we should find a way to cache this
+	client, err := client.NewClient(cmd)
+	if err != nil {
+		return nil, err
+	}
+
+	request := &v1.ReadSchemaRequest{}
+
+	resp, err := client.ReadSchema(cmd.Context(), request)
+	if err != nil {
+		return nil, err
+	}
+
+	schemaText := resp.SchemaText
+	if len(schemaText) == 0 {
+		return nil, errors.New("no schema defined")
+	}
+
+	compiledSchema, err := compiler.Compile(
+		compiler.InputSchema{Source: "schema", SchemaString: schemaText},
+		compiler.AllowUnprefixedObjectType(),
+		compiler.SkipValidation(),
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	return compiledSchema, nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/commands/permission.go b/vendor/github.com/authzed/zed/internal/commands/permission.go
new file mode 100644
index 0000000..58b022c
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/commands/permission.go
@@ -0,0 +1,673 @@
+package commands
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "strings"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/jzelinskie/stringz"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+ "github.com/spf13/pflag"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+ "google.golang.org/protobuf/encoding/protojson"
+ "google.golang.org/protobuf/encoding/prototext"
+
+ "github.com/authzed/authzed-go/pkg/requestmeta"
+ "github.com/authzed/authzed-go/pkg/responsemeta"
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/console"
+ "github.com/authzed/zed/internal/printers"
+)
+
+var ErrMultipleConsistencies = errors.New("provided multiple consistency flags")
+
+// registerConsistencyFlags adds the shared consistency flags (at-exactly,
+// at-least, min-latency, full) to a command's flag set; they are interpreted
+// by consistencyFromCmd.
+func registerConsistencyFlags(flags *pflag.FlagSet) {
+	flags.String("consistency-at-exactly", "", "evaluate at the provided zedtoken")
+	flags.String("consistency-at-least", "", "evaluate at least as consistent as the provided zedtoken")
+	flags.Bool("consistency-min-latency", false, "evaluate at the zedtoken preferred by the database")
+	flags.Bool("consistency-full", false, "evaluate at the newest zedtoken in the database")
+}
+
+// consistencyFromCmd builds the v1.Consistency for a request from the flags
+// registered by registerConsistencyFlags (plus the deprecated --revision
+// flag). Setting more than one returns ErrMultipleConsistencies; setting
+// none defaults to minimize-latency.
+func consistencyFromCmd(cmd *cobra.Command) (c *v1.Consistency, err error) {
+	if cobrautil.MustGetBool(cmd, "consistency-full") {
+		c = &v1.Consistency{Requirement: &v1.Consistency_FullyConsistent{FullyConsistent: true}}
+	}
+	if atLeast := cobrautil.MustGetStringExpanded(cmd, "consistency-at-least"); atLeast != "" {
+		if c != nil {
+			return nil, ErrMultipleConsistencies
+		}
+		c = &v1.Consistency{Requirement: &v1.Consistency_AtLeastAsFresh{AtLeastAsFresh: &v1.ZedToken{Token: atLeast}}}
+	}
+
+	// Deprecated (hidden) flag.
+	if revision := cobrautil.MustGetStringExpanded(cmd, "revision"); revision != "" {
+		if c != nil {
+			return nil, ErrMultipleConsistencies
+		}
+		c = &v1.Consistency{Requirement: &v1.Consistency_AtLeastAsFresh{AtLeastAsFresh: &v1.ZedToken{Token: revision}}}
+	}
+
+	if exact := cobrautil.MustGetStringExpanded(cmd, "consistency-at-exactly"); exact != "" {
+		if c != nil {
+			return nil, ErrMultipleConsistencies
+		}
+		c = &v1.Consistency{Requirement: &v1.Consistency_AtExactSnapshot{AtExactSnapshot: &v1.ZedToken{Token: exact}}}
+	}
+
+	if c == nil {
+		// No consistency flag provided; prefer the lowest-latency evaluation.
+		c = &v1.Consistency{Requirement: &v1.Consistency_MinimizeLatency{MinimizeLatency: true}}
+	}
+	return
+}
+
+// RegisterPermissionCmd attaches the "permission" command and all of its
+// subcommands (check, bulk, expand, lookup/lookup-resources, lookup-subjects)
+// to rootCmd, registering each subcommand's flags, and returns permissionCmd.
+func RegisterPermissionCmd(rootCmd *cobra.Command) *cobra.Command {
+	rootCmd.AddCommand(permissionCmd)
+
+	permissionCmd.AddCommand(checkCmd)
+	checkCmd.Flags().Bool("json", false, "output as JSON")
+	checkCmd.Flags().String("revision", "", "optional revision at which to check")
+	_ = checkCmd.Flags().MarkHidden("revision")
+	checkCmd.Flags().Bool("explain", false, "requests debug information from SpiceDB and prints out a trace of the requests")
+	checkCmd.Flags().Bool("schema", false, "requests debug information from SpiceDB and prints out the schema used")
+	checkCmd.Flags().Bool("error-on-no-permission", false, "if true, zed will return exit code 1 if subject does not have unconditional permission")
+	checkCmd.Flags().String("caveat-context", "", "the caveat context to send along with the check, in JSON form")
+	registerConsistencyFlags(checkCmd.Flags())
+
+	permissionCmd.AddCommand(checkBulkCmd)
+	checkBulkCmd.Flags().String("revision", "", "optional revision at which to check")
+	checkBulkCmd.Flags().Bool("json", false, "output as JSON")
+	checkBulkCmd.Flags().Bool("explain", false, "requests debug information from SpiceDB and prints out a trace of the requests")
+	checkBulkCmd.Flags().Bool("schema", false, "requests debug information from SpiceDB and prints out the schema used")
+	registerConsistencyFlags(checkBulkCmd.Flags())
+
+	permissionCmd.AddCommand(expandCmd)
+	expandCmd.Flags().Bool("json", false, "output as JSON")
+	expandCmd.Flags().String("revision", "", "optional revision at which to check")
+	registerConsistencyFlags(expandCmd.Flags())
+
+	// NOTE: `lookup` is an alias of `lookup-resources` (below)
+	// and must have the same list of flags in order for it to work.
+	permissionCmd.AddCommand(lookupCmd)
+	lookupCmd.Flags().Bool("json", false, "output as JSON")
+	lookupCmd.Flags().String("revision", "", "optional revision at which to check")
+	lookupCmd.Flags().String("caveat-context", "", "the caveat context to send along with the lookup, in JSON form")
+	lookupCmd.Flags().Uint32("page-limit", 0, "limit of relations returned per page")
+	registerConsistencyFlags(lookupCmd.Flags())
+
+	permissionCmd.AddCommand(lookupResourcesCmd)
+	lookupResourcesCmd.Flags().Bool("json", false, "output as JSON")
+	lookupResourcesCmd.Flags().String("revision", "", "optional revision at which to check")
+	lookupResourcesCmd.Flags().String("caveat-context", "", "the caveat context to send along with the lookup, in JSON form")
+	lookupResourcesCmd.Flags().Uint32("page-limit", 0, "limit of relations returned per page")
+	lookupResourcesCmd.Flags().String("cursor", "", "resume pagination from a specific cursor token")
+	lookupResourcesCmd.Flags().Bool("show-cursor", true, "display the cursor token after pagination")
+	registerConsistencyFlags(lookupResourcesCmd.Flags())
+
+	permissionCmd.AddCommand(lookupSubjectsCmd)
+	lookupSubjectsCmd.Flags().Bool("json", false, "output as JSON")
+	lookupSubjectsCmd.Flags().String("revision", "", "optional revision at which to check")
+	lookupSubjectsCmd.Flags().String("caveat-context", "", "the caveat context to send along with the lookup, in JSON form")
+	registerConsistencyFlags(lookupSubjectsCmd.Flags())
+
+	return permissionCmd
+}
+
+// permissionCmd is the parent "zed permission" command; the query subcommands
+// below hang off of it.
+var permissionCmd = &cobra.Command{
+	Use:     "permission <subcommand>",
+	Short:   "Query the permissions in a permissions system",
+	Aliases: []string{"perm"},
+}
+
+// checkBulkCmd checks several resource#permission@subject triples in one RPC.
+// NOTE(review): the Short text reads awkwardly ("a permissions ... a
+// resource-subject pairs") but is user-facing upstream help text, so it is
+// left unchanged here.
+var checkBulkCmd = &cobra.Command{
+	Use:   "bulk <resource:id#permission@subject:id> <resource:id#permission@subject:id> ...",
+	Short: "Check a permissions in bulk exists for a resource-subject pairs",
+	Args:  ValidationWrapper(cobra.MinimumNArgs(1)),
+	RunE:  checkBulkCmdFunc,
+}
+
+// checkCmd checks a single permission for a single resource/subject pair.
+var checkCmd = &cobra.Command{
+	Use:               "check <resource:id> <permission> <subject:id>",
+	Short:             "Check that a permission exists for a subject",
+	Args:              ValidationWrapper(cobra.ExactArgs(3)),
+	ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectID),
+	RunE:              checkCmdFunc,
+}
+
+// expandCmd prints the expansion tree of a permission on a resource.
+var expandCmd = &cobra.Command{
+	Use:               "expand <permission> <resource:id>",
+	Short:             "Expand the structure of a permission",
+	Args:              ValidationWrapper(cobra.ExactArgs(2)),
+	ValidArgsFunction: cobra.NoFileCompletions,
+	RunE:              expandCmdFunc,
+}
+
+// lookupResourcesCmd enumerates resources accessible to a subject.
+var lookupResourcesCmd = &cobra.Command{
+	Use:               "lookup-resources <type> <permission> <subject:id>",
+	Short:             "Enumerates resources of a given type for which the subject has permission",
+	Args:              ValidationWrapper(cobra.ExactArgs(3)),
+	ValidArgsFunction: GetArgs(ResourceType, Permission, SubjectID),
+	RunE:              lookupResourcesCmdFunc,
+}
+
+// lookupCmd is the deprecated, hidden alias of lookup-resources; it shares
+// the same RunE and must keep the same flags (see RegisterPermissionCmd).
+var lookupCmd = &cobra.Command{
+	Use:               "lookup <type> <permission> <subject:id>",
+	Short:             "Enumerates the resources of a given type for which the subject has permission",
+	Args:              ValidationWrapper(cobra.ExactArgs(3)),
+	ValidArgsFunction: GetArgs(ResourceType, Permission, SubjectID),
+	RunE:              lookupResourcesCmdFunc,
+	Deprecated:        "prefer lookup-resources",
+	Hidden:            true,
+}
+
+// lookupSubjectsCmd enumerates subjects with permission on a resource.
+var lookupSubjectsCmd = &cobra.Command{
+	Use:               "lookup-subjects <resource:id> <permission> <subject_type#optional_subject_relation>",
+	Short:             "Enumerates the subjects of a given type for which the subject has permission on the resource",
+	Args:              ValidationWrapper(cobra.ExactArgs(3)),
+	ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation),
+	RunE:              lookupSubjectsCmdFunc,
+}
+
+// checkCmdFunc implements "zed permission check": it issues a CheckPermission
+// RPC for <resource:id> <permission> <subject:id> and prints "true", "false",
+// or "caveated" (or the full response as JSON with --json). With --explain or
+// --schema it requests debug information and renders it; with
+// --error-on-no-permission it exits with status 1 unless the subject has
+// unconditional permission.
+func checkCmdFunc(cmd *cobra.Command, args []string) error {
+	var objectNS, objectID string
+	err := stringz.SplitExact(args[0], ":", &objectNS, &objectID)
+	if err != nil {
+		return err
+	}
+
+	relation := args[1]
+
+	subjectNS, subjectID, subjectRel, err := ParseSubject(args[2])
+	if err != nil {
+		return err
+	}
+
+	caveatContext, err := GetCaveatContext(cmd)
+	if err != nil {
+		return err
+	}
+
+	consistency, err := consistencyFromCmd(cmd)
+	if err != nil {
+		return err
+	}
+
+	client, err := client.NewClient(cmd)
+	if err != nil {
+		return err
+	}
+
+	request := &v1.CheckPermissionRequest{
+		Resource: &v1.ObjectReference{
+			ObjectType: objectNS,
+			ObjectId:   objectID,
+		},
+		Permission: relation,
+		Subject: &v1.SubjectReference{
+			Object: &v1.ObjectReference{
+				ObjectType: subjectNS,
+				ObjectId:   subjectID,
+			},
+			OptionalRelation: subjectRel,
+		},
+		Context:     caveatContext,
+		Consistency: consistency,
+	}
+	log.Trace().Interface("request", request).Send()
+
+	ctx := cmd.Context()
+	if cobrautil.MustGetBool(cmd, "explain") || cobrautil.MustGetBool(cmd, "schema") {
+		log.Info().Msg("debugging requested on check")
+		ctx = requestmeta.AddRequestHeaders(ctx, requestmeta.RequestDebugInformation)
+		request.WithTracing = true
+	}
+
+	var trailerMD metadata.MD
+	resp, err := client.CheckPermission(ctx, request, grpc.Trailer(&trailerMD))
+	if err != nil {
+		var debugInfo *v1.DebugInformation
+
+		// Check for the debug trace contained in the error details.
+		if errInfo, ok := grpcErrorInfoFrom(err); ok {
+			if encodedDebugInfo, ok := errInfo.Metadata["debug_trace_proto_text"]; ok {
+				debugInfo = &v1.DebugInformation{}
+				if uerr := prototext.Unmarshal([]byte(encodedDebugInfo), debugInfo); uerr != nil {
+					return uerr
+				}
+			}
+		}
+
+		// Render whatever debug information we could recover, then surface
+		// the original RPC error.
+		derr := displayDebugInformationIfRequested(cmd, debugInfo, trailerMD, true)
+		if derr != nil {
+			return derr
+		}
+
+		return err
+	}
+
+	if cobrautil.MustGetBool(cmd, "json") {
+		prettyProto, err := PrettyProto(resp)
+		if err != nil {
+			return err
+		}
+
+		console.Println(string(prettyProto))
+		return nil
+	}
+
+	switch resp.Permissionship {
+	case v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION:
+		log.Warn().Strs("fields", resp.PartialCaveatInfo.MissingRequiredContext).Msg("missing fields in caveat context")
+		console.Println("caveated")
+
+	case v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION:
+		console.Println("true")
+
+	case v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION:
+		console.Println("false")
+
+	default:
+		return fmt.Errorf("unknown permission response: %v", resp.Permissionship)
+	}
+
+	err = displayDebugInformationIfRequested(cmd, resp.DebugTrace, trailerMD, false)
+	if err != nil {
+		return err
+	}
+
+	if cobrautil.MustGetBool(cmd, "error-on-no-permission") {
+		if resp.Permissionship != v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION {
+			// Exit non-zero for scripting; caveated results also land here.
+			os.Exit(1)
+		}
+	}
+
+	return nil
+}
+
+// checkBulkCmdFunc implements "zed permission bulk": each argument is parsed
+// as a relationship string and checked in a single CheckBulkPermissions RPC.
+// Results are printed one per line as "<resource>#<permission>@<subject> =>
+// true|false|caveated" (or the full response as JSON with --json); per-item
+// errors are printed in place rather than aborting the whole run.
+func checkBulkCmdFunc(cmd *cobra.Command, args []string) error {
+	items := make([]*v1.CheckBulkPermissionsRequestItem, 0, len(args))
+	for _, arg := range args {
+		rel, err := tuple.ParseV1Rel(arg)
+		if err != nil {
+			return fmt.Errorf("unable to parse relation: %s", arg)
+		}
+
+		item := &v1.CheckBulkPermissionsRequestItem{
+			Resource: &v1.ObjectReference{
+				ObjectType: rel.Resource.ObjectType,
+				ObjectId:   rel.Resource.ObjectId,
+			},
+			Permission: rel.Relation,
+			Subject: &v1.SubjectReference{
+				Object: &v1.ObjectReference{
+					ObjectType: rel.Subject.Object.ObjectType,
+					ObjectId:   rel.Subject.Object.ObjectId,
+				},
+			},
+		}
+		// Carry any caveat context from the parsed relationship into the item.
+		if rel.OptionalCaveat != nil {
+			item.Context = rel.OptionalCaveat.Context
+		}
+		items = append(items, item)
+	}
+
+	consistency, err := consistencyFromCmd(cmd)
+	if err != nil {
+		return err
+	}
+
+	bulk := &v1.CheckBulkPermissionsRequest{
+		Consistency: consistency,
+		Items:       items,
+	}
+
+	log.Trace().Interface("request", bulk).Send()
+
+	ctx := cmd.Context()
+	c, err := client.NewClient(cmd)
+	if err != nil {
+		return err
+	}
+
+	if cobrautil.MustGetBool(cmd, "explain") || cobrautil.MustGetBool(cmd, "schema") {
+		bulk.WithTracing = true
+	}
+
+	resp, err := c.CheckBulkPermissions(ctx, bulk)
+	if err != nil {
+		return err
+	}
+
+	if cobrautil.MustGetBool(cmd, "json") {
+		prettyProto, err := PrettyProto(resp)
+		if err != nil {
+			return err
+		}
+
+		console.Println(string(prettyProto))
+		return nil
+	}
+
+	for _, item := range resp.Pairs {
+		console.Printf("%s:%s#%s@%s:%s => ",
+			item.Request.Resource.ObjectType, item.Request.Resource.ObjectId, item.Request.Permission, item.Request.Subject.Object.ObjectType, item.Request.Subject.Object.ObjectId)
+
+		// Each pair carries either a check result or a per-item error.
+		switch responseType := item.Response.(type) {
+		case *v1.CheckBulkPermissionsPair_Item:
+			switch responseType.Item.Permissionship {
+			case v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION:
+				console.Println("caveated")
+
+			case v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION:
+				console.Println("true")
+
+			case v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION:
+				console.Println("false")
+			}
+
+			err = displayDebugInformationIfRequested(cmd, responseType.Item.DebugTrace, nil, false)
+			if err != nil {
+				return err
+			}
+
+		case *v1.CheckBulkPermissionsPair_Error:
+			console.Println(fmt.Sprintf("error: %s", responseType.Error))
+		}
+	}
+
+	return nil
+}
+
+// expandCmdFunc implements "zed permission expand": it calls
+// ExpandPermissionTree for <permission> on <resource:id> and prints the
+// resulting tree (or the raw response as JSON with --json).
+func expandCmdFunc(cmd *cobra.Command, args []string) error {
+	relation := args[0]
+
+	var objectNS, objectID string
+	err := stringz.SplitExact(args[1], ":", &objectNS, &objectID)
+	if err != nil {
+		return err
+	}
+
+	consistency, err := consistencyFromCmd(cmd)
+	if err != nil {
+		return err
+	}
+
+	client, err := client.NewClient(cmd)
+	if err != nil {
+		return err
+	}
+
+	request := &v1.ExpandPermissionTreeRequest{
+		Resource: &v1.ObjectReference{
+			ObjectType: objectNS,
+			ObjectId:   objectID,
+		},
+		Permission:  relation,
+		Consistency: consistency,
+	}
+	log.Trace().Interface("request", request).Send()
+
+	resp, err := client.ExpandPermissionTree(cmd.Context(), request)
+	if err != nil {
+		return err
+	}
+
+	if cobrautil.MustGetBool(cmd, "json") {
+		prettyProto, err := PrettyProto(resp)
+		if err != nil {
+			return err
+		}
+
+		console.Println(string(prettyProto))
+		return nil
+	}
+
+	// Render the permission tree as an indented tree on the console.
+	tp := printers.NewTreePrinter()
+	printers.TreeNodeTree(tp, resp.TreeRoot)
+	tp.Print()
+
+	return nil
+}
+
+// newLookupResourcesPageCallbackForTests, when non-nil, is invoked with the
+// number of results read on each page; set only by tests to observe paging.
+var newLookupResourcesPageCallbackForTests func(readByPage uint)
+
+// lookupResourcesCmdFunc implements "zed permission lookup-resources" (and
+// its deprecated "lookup" alias): it streams LookupResources pages for
+// <type> <permission> <subject:id>, printing each resource ID (with caveat
+// info when conditional, and the raw response as JSON with --json). Paging
+// continues via the returned cursors until a short or empty page; the final
+// cursor is printed when --show-cursor is set.
+func lookupResourcesCmdFunc(cmd *cobra.Command, args []string) error {
+	objectNS := args[0]
+	relation := args[1]
+	subjectNS, subjectID, subjectRel, err := ParseSubject(args[2])
+	if err != nil {
+		return err
+	}
+
+	pageLimit := cobrautil.MustGetUint32(cmd, "page-limit")
+	caveatContext, err := GetCaveatContext(cmd)
+	if err != nil {
+		return err
+	}
+
+	consistency, err := consistencyFromCmd(cmd)
+	if err != nil {
+		return err
+	}
+
+	client, err := client.NewClient(cmd)
+	if err != nil {
+		return err
+	}
+
+	// Optionally resume from a caller-provided cursor token.
+	var cursor *v1.Cursor
+	if cursorStr := cobrautil.MustGetString(cmd, "cursor"); cursorStr != "" {
+		cursor = &v1.Cursor{Token: cursorStr}
+	}
+
+	var totalCount uint
+	for {
+		request := &v1.LookupResourcesRequest{
+			ResourceObjectType: objectNS,
+			Permission:         relation,
+			Subject: &v1.SubjectReference{
+				Object: &v1.ObjectReference{
+					ObjectType: subjectNS,
+					ObjectId:   subjectID,
+				},
+				OptionalRelation: subjectRel,
+			},
+			Context:        caveatContext,
+			Consistency:    consistency,
+			OptionalLimit:  pageLimit,
+			OptionalCursor: cursor,
+		}
+		log.Trace().Interface("request", request).Uint32("page-limit", pageLimit).Send()
+
+		respStream, err := client.LookupResources(cmd.Context(), request)
+		if err != nil {
+			return err
+		}
+
+		var count uint
+
+	stream:
+		for {
+			resp, err := respStream.Recv()
+			switch {
+			case errors.Is(err, io.EOF):
+				break stream
+			case err != nil:
+				return err
+			default:
+				count++
+				totalCount++
+				if cobrautil.MustGetBool(cmd, "json") {
+					prettyProto, err := PrettyProto(resp)
+					if err != nil {
+						return err
+					}
+
+					console.Println(string(prettyProto))
+				}
+
+				console.Println(prettyLookupPermissionship(resp.ResourceObjectId, resp.Permissionship, resp.PartialCaveatInfo))
+				// Track the latest cursor so the next page resumes after it.
+				cursor = resp.AfterResultCursor
+			}
+		}
+
+		if newLookupResourcesPageCallbackForTests != nil {
+			newLookupResourcesPageCallbackForTests(count)
+		}
+		// A short (or empty) page, or no page limit at all, means the stream
+		// is exhausted.
+		if count == 0 || pageLimit == 0 || count < uint(pageLimit) {
+			log.Trace().Interface("request", request).Uint32("page-limit", pageLimit).Uint("count", totalCount).Send()
+			break
+		}
+	}
+
+	showCursor := cobrautil.MustGetBool(cmd, "show-cursor")
+	if showCursor && cursor != nil {
+		console.Printf("Last cursor: %s\n", cursor.Token)
+	}
+
+	return nil
+}
+
+// lookupSubjectsCmdFunc implements "zed permission lookup-subjects": it
+// streams LookupSubjects results for <resource:id> <permission>
+// <subject_type#optional_subject_relation>, printing each subject (with
+// caveat info and any excluded subjects; raw JSON first with --json) until
+// the stream ends.
+func lookupSubjectsCmdFunc(cmd *cobra.Command, args []string) error {
+	var objectNS, objectID string
+	err := stringz.SplitExact(args[0], ":", &objectNS, &objectID)
+	if err != nil {
+		return err
+	}
+
+	permission := args[1]
+
+	subjectType, subjectRelation := ParseType(args[2])
+
+	caveatContext, err := GetCaveatContext(cmd)
+	if err != nil {
+		return err
+	}
+
+	consistency, err := consistencyFromCmd(cmd)
+	if err != nil {
+		return err
+	}
+
+	client, err := client.NewClient(cmd)
+	if err != nil {
+		return err
+	}
+	request := &v1.LookupSubjectsRequest{
+		Resource: &v1.ObjectReference{
+			ObjectType: objectNS,
+			ObjectId:   objectID,
+		},
+		Permission:              permission,
+		SubjectObjectType:       subjectType,
+		OptionalSubjectRelation: subjectRelation,
+		Context:                 caveatContext,
+		Consistency:             consistency,
+	}
+	log.Trace().Interface("request", request).Send()
+
+	respStream, err := client.LookupSubjects(cmd.Context(), request)
+	if err != nil {
+		return err
+	}
+
+	// Drain the stream; io.EOF signals normal completion.
+	for {
+		resp, err := respStream.Recv()
+		switch {
+		case errors.Is(err, io.EOF):
+			return nil
+		case err != nil:
+			return err
+		default:
+			if cobrautil.MustGetBool(cmd, "json") {
+				prettyProto, err := PrettyProto(resp)
+				if err != nil {
+					return err
+				}
+
+				console.Println(string(prettyProto))
+			}
+			console.Printf("%s:%s%s\n",
+				subjectType,
+				prettyLookupPermissionship(resp.Subject.SubjectObjectId, resp.Subject.Permissionship, resp.Subject.PartialCaveatInfo),
+				excludedSubjectsString(resp.ExcludedSubjects),
+			)
+		}
+	}
+}
+
+// excludedSubjectsString formats the excluded subjects of a LookupSubjects
+// result as a " - {...}" suffix listing each excluded subject (with caveat
+// info), or returns the empty string when there are none.
+func excludedSubjectsString(excluded []*v1.ResolvedSubject) string {
+	if len(excluded) == 0 {
+		return ""
+	}
+
+	var b strings.Builder
+	fmt.Fprintf(&b, " - {\n")
+	for _, subj := range excluded {
+		fmt.Fprintf(&b, "\t%s\n", prettyLookupPermissionship(
+			subj.SubjectObjectId,
+			subj.Permissionship,
+			subj.PartialCaveatInfo,
+		))
+	}
+	fmt.Fprintf(&b, "}")
+	return b.String()
+}
+
+// prettyLookupPermissionship renders an object ID, appending a
+// "(caveated, missing context: ...)" note when the permissionship is
+// conditional.
+func prettyLookupPermissionship(objectID string, p v1.LookupPermissionship, info *v1.PartialCaveatInfo) string {
+	var b strings.Builder
+	fmt.Fprint(&b, objectID)
+	if p == v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_CONDITIONAL_PERMISSION {
+		fmt.Fprintf(&b, " (caveated, missing context: %s)", strings.Join(info.MissingRequiredContext, ", "))
+	}
+	return b.String()
+}
+
+// displayDebugInformationIfRequested renders check debug information when
+// --explain and/or --schema is set. It prefers the debug info passed in the
+// response payload (SpiceDB >= 1.30) and falls back to decoding it from the
+// response trailer metadata (< 1.30); hasError controls how the check trace
+// is styled. It is a no-op when neither flag is set.
+func displayDebugInformationIfRequested(cmd *cobra.Command, debug *v1.DebugInformation, trailerMD metadata.MD, hasError bool) error {
+	if cobrautil.MustGetBool(cmd, "explain") || cobrautil.MustGetBool(cmd, "schema") {
+		debugInfo := &v1.DebugInformation{}
+		// DebugInformation comes in trailer < 1.30, and in response payload >= 1.30
+		if debug == nil {
+			found, err := responsemeta.GetResponseTrailerMetadataOrNil(trailerMD, responsemeta.DebugInformation)
+			if err != nil {
+				return err
+			}
+
+			if found == nil {
+				log.Warn().Msg("No debugging information returned for the check")
+				return nil
+			}
+
+			err = protojson.Unmarshal([]byte(*found), debugInfo)
+			if err != nil {
+				return err
+			}
+		} else {
+			debugInfo = debug
+		}
+
+		if debugInfo.Check == nil {
+			log.Warn().Msg("No trace found for the check")
+			return nil
+		}
+
+		if cobrautil.MustGetBool(cmd, "explain") {
+			tp := printers.NewTreePrinter()
+			printers.DisplayCheckTrace(debugInfo.Check, tp, hasError)
+			tp.Print()
+		}
+
+		if cobrautil.MustGetBool(cmd, "schema") {
+			console.Println()
+			console.Println(debugInfo.SchemaUsed)
+		}
+	}
+	return nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/commands/relationship.go b/vendor/github.com/authzed/zed/internal/commands/relationship.go
new file mode 100644
index 0000000..0471f85
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/commands/relationship.go
@@ -0,0 +1,561 @@
+package commands
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "strings"
+ "time"
+ "unicode"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/jzelinskie/stringz"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+ "google.golang.org/genproto/googleapis/rpc/errdetails"
+ "google.golang.org/grpc/status"
+ "google.golang.org/protobuf/types/known/timestamppb"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/console"
+)
+
+// RegisterRelationshipCmd attaches the `relationship` command tree
+// (create/touch/delete/read/bulk-delete) and their flags to rootCmd,
+// returning the parent relationship command.
+func RegisterRelationshipCmd(rootCmd *cobra.Command) *cobra.Command {
+	rootCmd.AddCommand(relationshipCmd)
+
+	relationshipCmd.AddCommand(createCmd)
+	createCmd.Flags().Bool("json", false, "output as JSON")
+	createCmd.Flags().String("caveat", "", `the caveat for the relationship, with format: 'caveat_name:{"some":"context"}'`)
+	createCmd.Flags().String("expiration-time", "", `the expiration time of the relationship in RFC 3339 format`)
+	createCmd.Flags().IntP("batch-size", "b", 100, "batch size when writing streams of relationships from stdin")
+
+	relationshipCmd.AddCommand(touchCmd)
+	touchCmd.Flags().Bool("json", false, "output as JSON")
+	touchCmd.Flags().String("caveat", "", `the caveat for the relationship, with format: 'caveat_name:{"some":"context"}'`)
+	touchCmd.Flags().String("expiration-time", "", `the expiration time for the relationship in RFC 3339 format`)
+	touchCmd.Flags().IntP("batch-size", "b", 100, "batch size when writing streams of relationships from stdin")
+
+	relationshipCmd.AddCommand(deleteCmd)
+	deleteCmd.Flags().Bool("json", false, "output as JSON")
+	deleteCmd.Flags().IntP("batch-size", "b", 100, "batch size when deleting streams of relationships from stdin")
+
+	relationshipCmd.AddCommand(readCmd)
+	readCmd.Flags().Bool("json", false, "output as JSON")
+	readCmd.Flags().String("revision", "", "optional revision at which to check")
+	_ = readCmd.Flags().MarkHidden("revision")
+	readCmd.Flags().String("subject-filter", "", "optional subject filter")
+	readCmd.Flags().Uint32("page-limit", 100, "limit of relations returned per page")
+	registerConsistencyFlags(readCmd.Flags())
+
+	relationshipCmd.AddCommand(bulkDeleteCmd)
+	bulkDeleteCmd.Flags().Bool("force", false, "force deletion of all elements in batches defined by <optional-limit>")
+	bulkDeleteCmd.Flags().String("subject-filter", "", "optional subject filter")
+	bulkDeleteCmd.Flags().Uint32("optional-limit", 1000, "the max amount of elements to delete. If you want to delete all in batches of size <optional-limit>, set --force to true")
+	bulkDeleteCmd.Flags().Bool("estimate-count", true, "estimate the count of relationships to be deleted")
+	_ = bulkDeleteCmd.Flags().MarkDeprecated("estimate-count", "no longer used, make use of --optional-limit instead")
+	return relationshipCmd
+}
+
+// relationshipCmd is the parent command for all relationship subcommands.
+var relationshipCmd = &cobra.Command{
+	Use:   "relationship <subcommand>",
+	Short: "Query and mutate the relationships in a permissions system",
+}
+
+// createCmd writes relationships with CREATE semantics (errors if present).
+// Arguments may also be streamed via stdin, one relationship per line.
+var createCmd = &cobra.Command{
+	Use:               "create <resource:id> <relation> <subject:id#optional_subject_relation>",
+	Short:             "Create a relationship for a subject",
+	Args:              ValidationWrapper(StdinOrExactArgs(3)),
+	ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation),
+	RunE:              writeRelationshipCmdFunc(v1.RelationshipUpdate_OPERATION_CREATE, os.Stdin),
+}
+
+// touchCmd writes relationships with TOUCH semantics (idempotent upsert).
+var touchCmd = &cobra.Command{
+	Use:               "touch <resource:id> <relation> <subject:id#optional_subject_relation>",
+	Short:             "Idempotently updates a relationship for a subject",
+	Args:              ValidationWrapper(StdinOrExactArgs(3)),
+	ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation),
+	RunE:              writeRelationshipCmdFunc(v1.RelationshipUpdate_OPERATION_TOUCH, os.Stdin),
+}
+
+// deleteCmd removes a single relationship (or a stdin stream of them).
+var deleteCmd = &cobra.Command{
+	Use:               "delete <resource:id> <relation> <subject:id#optional_subject_relation>",
+	Short:             "Deletes a relationship",
+	Args:              ValidationWrapper(StdinOrExactArgs(3)),
+	ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation),
+	RunE:              writeRelationshipCmdFunc(v1.RelationshipUpdate_OPERATION_DELETE, os.Stdin),
+}
+
+const readCmdHelpLong = `Enumerates relationships matching the provided pattern.
+
+To filter returned relationships using a resource ID prefix, append a '%' to the resource ID:
+
+zed relationship read some-type:some-prefix-%
+`
+
+// readCmd enumerates relationships matching an optional filter pattern.
+var readCmd = &cobra.Command{
+	Use:               "read <resource_type:optional_resource_id> <optional_relation> <optional_subject_type:optional_subject_id#optional_subject_relation>",
+	Short:             "Enumerates relationships matching the provided pattern",
+	Long:              readCmdHelpLong,
+	Args:              ValidationWrapper(cobra.RangeArgs(1, 3)),
+	ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation),
+	RunE:              readRelationships,
+}
+
+// bulkDeleteCmd deletes all relationships matching a filter pattern.
+var bulkDeleteCmd = &cobra.Command{
+	Use:               "bulk-delete <resource_type:optional_resource_id> <optional_relation> <optional_subject_type:optional_subject_id#optional_subject_relation>",
+	Short:             "Deletes relationships matching the provided pattern en masse",
+	Args:              ValidationWrapper(cobra.RangeArgs(1, 3)),
+	ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation),
+	RunE:              bulkDeleteRelationships,
+}
+
+// StdinOrExactArgs returns a cobra args validator that accepts either zero
+// positional arguments when stdin is not a terminal (input is being piped),
+// or exactly n arguments otherwise.
+func StdinOrExactArgs(n int) cobra.PositionalArgs {
+	return func(cmd *cobra.Command, args []string) error {
+		if ok := isArgsViaFile(os.Stdin) && len(args) == 0; ok {
+			return nil
+		}
+
+		return cobra.ExactArgs(n)(cmd, args)
+	}
+}
+
+// isArgsViaFile reports whether the given file is a pipe/redirect rather than
+// an interactive terminal, i.e. whether arguments are being streamed in.
+func isArgsViaFile(file *os.File) bool {
+	return !isFileTerminal(file)
+}
+
+// bulkDeleteRelationships deletes every relationship matching the filter built
+// from the positional arguments, in batches of --optional-limit. With --force,
+// partial deletions are allowed and the loop repeats until the server reports
+// DELETION_PROGRESS_COMPLETE. On success the final ZedToken is printed.
+func bulkDeleteRelationships(cmd *cobra.Command, args []string) error {
+	spicedbClient, err := client.NewClient(cmd)
+	if err != nil {
+		return err
+	}
+
+	filter, err := buildRelationshipsFilter(cmd, args)
+	if err != nil {
+		return err
+	}
+
+	bar := console.CreateProgressBar("deleting relationships")
+	defer func() {
+		_ = bar.Finish()
+	}()
+
+	allowPartialDeletions := cobrautil.MustGetBool(cmd, "force")
+	optionalLimit := cobrautil.MustGetUint32(cmd, "optional-limit")
+
+	var resp *v1.DeleteRelationshipsResponse
+	for {
+		delRequest := &v1.DeleteRelationshipsRequest{
+			RelationshipFilter:            filter,
+			OptionalLimit:                 optionalLimit,
+			OptionalAllowPartialDeletions: allowPartialDeletions,
+		}
+		log.Trace().Interface("request", delRequest).Msg("deleting relationships")
+
+		resp, err = spicedbClient.DeleteRelationships(cmd.Context(), delRequest)
+		// Surface a friendlier message when the server rejects the delete for
+		// exceeding the transactional limit.
+		if errorInfo, ok := grpcErrorInfoFrom(err); ok {
+			if errorInfo.GetReason() == v1.ErrorReason_ERROR_REASON_TOO_MANY_RELATIONSHIPS_FOR_TRANSACTIONAL_DELETE.String() {
+				resourceType := "relationships"
+				if returnedResourceType, ok := errorInfo.GetMetadata()["filter_resource_type"]; ok {
+					resourceType = returnedResourceType
+				}
+
+				return fmt.Errorf("could not delete %s, as more than %s relationships were found. Consider increasing --optional-limit or deleting all relationships using --force",
+					resourceType,
+					errorInfo.GetMetadata()["limit"])
+			}
+		}
+		if err != nil {
+			return err
+		}
+
+		if resp.DeletionProgress == v1.DeleteRelationshipsResponse_DELETION_PROGRESS_COMPLETE {
+			break
+		}
+
+		// NOTE(review): the bar advances by the batch limit, not the actual
+		// deleted count — an approximation of progress.
+		if err := bar.Add(int(optionalLimit)); err != nil {
+			return err
+		}
+	}
+
+	_ = bar.Finish()
+	console.Println(resp.DeletedAt.GetToken())
+	return nil
+}
+
+// grpcErrorInfoFrom extracts the first errdetails.ErrorInfo detail from a gRPC
+// status error, if any. Returns (nil, false) for nil errors or errors without
+// an ErrorInfo detail.
+func grpcErrorInfoFrom(err error) (*errdetails.ErrorInfo, bool) {
+	if err == nil {
+		return nil, false
+	}
+
+	if s, ok := status.FromError(err); ok {
+		for _, d := range s.Details() {
+			if errInfo, ok := d.(*errdetails.ErrorInfo); ok {
+				return errInfo, true
+			}
+		}
+	}
+
+	return nil, false
+}
+
+// buildRelationshipsFilter constructs a v1.RelationshipFilter from positional
+// args: [0] resource_type[:id-or-prefix%], optional [1] relation, optional
+// [2] subject filter. The subject filter may instead come from the
+// --subject-filter flag; specifying both is an error.
+func buildRelationshipsFilter(cmd *cobra.Command, args []string) (*v1.RelationshipFilter, error) {
+	filter := &v1.RelationshipFilter{ResourceType: args[0]}
+
+	if strings.Contains(args[0], ":") {
+		var resourceID string
+		err := stringz.SplitExact(args[0], ":", &filter.ResourceType, &resourceID)
+		if err != nil {
+			return nil, err
+		}
+
+		// A trailing '%' selects by resource-ID prefix rather than exact ID.
+		if strings.HasSuffix(resourceID, "%") {
+			filter.OptionalResourceIdPrefix = strings.TrimSuffix(resourceID, "%")
+		} else {
+			filter.OptionalResourceId = resourceID
+		}
+	}
+
+	if len(args) > 1 {
+		filter.OptionalRelation = args[1]
+	}
+
+	subjectFilter := cobrautil.MustGetString(cmd, "subject-filter")
+	if len(args) == 3 {
+		if subjectFilter != "" {
+			return nil, errors.New("cannot specify subject filter both positionally and via --subject-filter")
+		}
+		subjectFilter = args[2]
+	}
+
+	if subjectFilter != "" {
+		// With a ':' the filter names a specific subject (and optional
+		// relation); without one it matches every subject of that type.
+		if strings.Contains(subjectFilter, ":") {
+			subjectNS, subjectID, subjectRel, err := ParseSubject(subjectFilter)
+			if err != nil {
+				return nil, err
+			}
+
+			filter.OptionalSubjectFilter = &v1.SubjectFilter{
+				SubjectType:       subjectNS,
+				OptionalSubjectId: subjectID,
+				OptionalRelation: &v1.SubjectFilter_RelationFilter{
+					Relation: subjectRel,
+				},
+			}
+		} else {
+			filter.OptionalSubjectFilter = &v1.SubjectFilter{
+				SubjectType: subjectFilter,
+			}
+		}
+	}
+
+	return filter, nil
+}
+
+// readRelationships streams relationships matching the filter built from args,
+// printing each one, and paginates using the server-provided cursor in pages
+// of --page-limit until a short page indicates the end of results.
+func readRelationships(cmd *cobra.Command, args []string) error {
+	spicedbClient, err := client.NewClient(cmd)
+	if err != nil {
+		return err
+	}
+
+	filter, err := buildRelationshipsFilter(cmd, args)
+	if err != nil {
+		return err
+	}
+
+	request := &v1.ReadRelationshipsRequest{RelationshipFilter: filter}
+
+	limit := cobrautil.MustGetUint32(cmd, "page-limit")
+	request.OptionalLimit = limit
+	request.Consistency, err = consistencyFromCmd(cmd)
+	if err != nil {
+		return err
+	}
+
+	lastCursor := request.OptionalCursor
+	for {
+		// Resume from the cursor of the last relationship received.
+		request.OptionalCursor = lastCursor
+		var cursorToken string
+		if lastCursor != nil {
+			cursorToken = lastCursor.Token
+		}
+		log.Trace().Interface("request", request).Str("cursor", cursorToken).Msg("reading relationships page")
+		readRelClient, err := spicedbClient.ReadRelationships(cmd.Context(), request)
+		if err != nil {
+			return err
+		}
+
+		var relCount uint32
+		for {
+			if err := cmd.Context().Err(); err != nil {
+				return err
+			}
+
+			msg, err := readRelClient.Recv()
+			if errors.Is(err, io.EOF) {
+				break
+			}
+
+			if err != nil {
+				return err
+			}
+
+			lastCursor = msg.AfterResultCursor
+			relCount++
+			if err := printRelationship(cmd, msg); err != nil {
+				return err
+			}
+		}
+
+		// A short (or unlimited) page means we've seen everything.
+		if relCount < limit || limit == 0 {
+			return nil
+		}
+
+		// More rows than requested: the server ignored the limit, so
+		// further pagination would loop forever.
+		if relCount > limit {
+			log.Warn().Uint32("limit-specified", limit).Uint32("relationships-received", relCount).Msg("page limit ignored, pagination may not be supported by the server, consider updating SpiceDB")
+			return nil
+		}
+	}
+}
+
+// printRelationship writes one ReadRelationships response to the console,
+// either as pretty-printed JSON (--json) or in zed's space-separated
+// `resource:id relation subject:id` text form.
+func printRelationship(cmd *cobra.Command, msg *v1.ReadRelationshipsResponse) error {
+	if cobrautil.MustGetBool(cmd, "json") {
+		prettyProto, err := PrettyProto(msg)
+		if err != nil {
+			return err
+		}
+
+		console.Println(string(prettyProto))
+	} else {
+		relString, err := relationshipToString(msg.Relationship)
+		if err != nil {
+			return err
+		}
+		console.Println(relString)
+	}
+
+	return nil
+}
+
+// argsToRelationship parses the three positional arguments
+// (resource, relation, subject) into a v1.Relationship.
+func argsToRelationship(args []string) (*v1.Relationship, error) {
+	if len(args) != 3 {
+		return nil, fmt.Errorf("expected 3 arguments, but got %d", len(args))
+	}
+
+	rel, err := tupleToRel(args[0], args[1], args[2])
+	if err != nil {
+		return nil, errors.New("failed to parse input arguments")
+	}
+
+	return rel, nil
+}
+
+// relationshipToString formats a relationship as zed's CLI text form by
+// converting the canonical `resource#relation@subject` tuple string into
+// `resource relation subject` (only the first '#' and '@' are replaced).
+func relationshipToString(rel *v1.Relationship) (string, error) {
+	relString, err := tuple.V1StringRelationship(rel)
+	if err != nil {
+		return "", err
+	}
+
+	relString = strings.Replace(relString, "@", " ", 1)
+	relString = strings.Replace(relString, "#", " ", 1)
+	return relString, nil
+}
+
+// parseRelationshipLine splits a line of update input that comes from stdin
+// and returns the fields representing the 3 arguments. This is to handle
+// the fact that relationships specified via stdin can't escape spaces like
+// shell arguments. The third field keeps any remaining whitespace-separated
+// content intact (only the first two whitespace runs delimit fields).
+func parseRelationshipLine(line string) (string, string, string, error) {
+	line = strings.TrimSpace(line)
+	resourceIdx := strings.IndexFunc(line, unicode.IsSpace)
+	if resourceIdx == -1 {
+		// Zero or one field on the line: report how many were found.
+		args := 0
+		if line != "" {
+			args = 1
+		}
+		return "", "", "", fmt.Errorf("expected %s to have 3 arguments, but got %v", line, args)
+	}
+
+	resource := line[:resourceIdx]
+	rest := strings.TrimSpace(line[resourceIdx+1:])
+	relationIdx := strings.IndexFunc(rest, unicode.IsSpace)
+	if relationIdx == -1 {
+		args := 1
+		if strings.TrimSpace(rest) != "" {
+			args = 2
+		}
+		return "", "", "", fmt.Errorf("expected %s to have 3 arguments, but got %v", line, args)
+	}
+
+	relation := rest[:relationIdx]
+	rest = strings.TrimSpace(rest[relationIdx+1:])
+	if rest == "" {
+		return "", "", "", fmt.Errorf("expected %s to have 3 arguments, but got 2", line)
+	}
+
+	return resource, relation, rest, nil
+}
+
+// FileRelationshipParser returns a RelationshipParser that reads one
+// relationship per line from f, returning ErrExhaustedRelationships at EOF.
+func FileRelationshipParser(f *os.File) RelationshipParser {
+	scanner := bufio.NewScanner(f)
+	return func() (*v1.Relationship, error) {
+		if scanner.Scan() {
+			res, rel, subj, err := parseRelationshipLine(scanner.Text())
+			if err != nil {
+				return nil, err
+			}
+			return tupleToRel(res, rel, subj)
+		}
+		// Distinguish a scan error from a clean end-of-input.
+		if err := scanner.Err(); err != nil {
+			return nil, err
+		}
+		return nil, ErrExhaustedRelationships
+	}
+}
+
+// tupleToRel assembles the canonical `resource#relation@subject` tuple string
+// and parses it into a v1.Relationship.
+func tupleToRel(resource, relation, subject string) (*v1.Relationship, error) {
+	return tuple.ParseV1Rel(resource + "#" + relation + "@" + subject)
+}
+
+// SliceRelationshipParser returns a RelationshipParser that yields exactly one
+// relationship built from the three positional args, then reports exhaustion.
+func SliceRelationshipParser(args []string) RelationshipParser {
+	ran := false
+	return func() (*v1.Relationship, error) {
+		if ran {
+			return nil, ErrExhaustedRelationships
+		}
+		ran = true
+		return tupleToRel(args[0], args[1], args[2])
+	}
+}
+
+// writeUpdates sends a batch of relationship updates in one
+// WriteRelationships call and prints the result (JSON or the written
+// ZedToken). An empty batch is a no-op.
+func writeUpdates(ctx context.Context, spicedbClient client.Client, updates []*v1.RelationshipUpdate, json bool) error {
+	if len(updates) == 0 {
+		return nil
+	}
+	request := &v1.WriteRelationshipsRequest{
+		Updates:               updates,
+		OptionalPreconditions: nil,
+	}
+
+	log.Trace().Interface("request", request).Msg("writing relationships")
+	resp, err := spicedbClient.WriteRelationships(ctx, request)
+	if err != nil {
+		return err
+	}
+
+	if json {
+		prettyProto, err := PrettyProto(resp)
+		if err != nil {
+			return err
+		}
+
+		console.Println(string(prettyProto))
+	} else {
+		console.Println(resp.WrittenAt.GetToken())
+	}
+
+	return nil
+}
+
+// RelationshipParser is a closure that can produce relationships.
+// When there are no more relationships, it will return ErrExhaustedRelationships.
+type RelationshipParser func() (*v1.Relationship, error)
+
+// ErrExhaustedRelationships signals that the last producible value of a RelationshipParser
+// has already been consumed.
+// Functions should return this error to signal a graceful end of input.
+var ErrExhaustedRelationships = errors.New("exhausted all relationships")
+
+// writeRelationshipCmdFunc builds the RunE handler shared by the create/touch/
+// delete subcommands. Relationships come either from the three positional args
+// or, when input is piped, line-by-line from `input`; they are written in
+// batches of --batch-size using the given update operation. Caveat and
+// expiration flags are applied to each relationship except for deletes.
+func writeRelationshipCmdFunc(operation v1.RelationshipUpdate_Operation, input *os.File) func(cmd *cobra.Command, args []string) error {
+	return func(cmd *cobra.Command, args []string) error {
+		parser := SliceRelationshipParser(args)
+		if isArgsViaFile(input) && len(args) == 0 {
+			parser = FileRelationshipParser(input)
+		}
+
+		spicedbClient, err := client.NewClient(cmd)
+		if err != nil {
+			return err
+		}
+
+		batchSize := cobrautil.MustGetInt(cmd, "batch-size")
+		updateBatch := make([]*v1.RelationshipUpdate, 0)
+		doJSON := cobrautil.MustGetBool(cmd, "json")
+
+		for {
+			rel, err := parser()
+			if errors.Is(err, ErrExhaustedRelationships) {
+				// Flush whatever remains in the final partial batch.
+				return writeUpdates(cmd.Context(), spicedbClient, updateBatch, doJSON)
+			} else if err != nil {
+				return err
+			}
+
+			if operation != v1.RelationshipUpdate_OPERATION_DELETE {
+				if err := handleCaveatFlag(cmd, rel); err != nil {
+					return err
+				}
+
+				if err := handleExpirationFlag(cmd, rel); err != nil {
+					return err
+				}
+			}
+
+			updateBatch = append(updateBatch, &v1.RelationshipUpdate{
+				Operation:    operation,
+				Relationship: rel,
+			})
+			if len(updateBatch) == batchSize {
+				if err := writeUpdates(cmd.Context(), spicedbClient, updateBatch, doJSON); err != nil {
+					return err
+				}
+				updateBatch = nil
+			}
+		}
+	}
+}
+
+// handleCaveatFlag applies the --caveat flag (`caveat_name:context_json`) to
+// rel, erroring if the relationship already carries a caveat.
+func handleCaveatFlag(cmd *cobra.Command, rel *v1.Relationship) error {
+	caveatString := cobrautil.MustGetString(cmd, "caveat")
+	if caveatString != "" {
+		if rel.OptionalCaveat != nil {
+			return errors.New("cannot specify a caveat in both the relationship and the --caveat flag")
+		}
+
+		// NOTE(review): SplitN on a non-empty string always yields at least
+		// one part, so this error branch appears unreachable — confirm intent.
+		parts := strings.SplitN(caveatString, ":", 2)
+		if len(parts) == 0 {
+			return fmt.Errorf("invalid --caveat argument. Must be in format `caveat_name:context`, but found `%s`", caveatString)
+		}
+
+		rel.OptionalCaveat = &v1.ContextualizedCaveat{
+			CaveatName: parts[0],
+		}
+
+		if len(parts) == 2 {
+			caveatCtx, err := ParseCaveatContext(parts[1])
+			if err != nil {
+				return err
+			}
+			rel.OptionalCaveat.Context = caveatCtx
+		}
+	}
+	return nil
+}
+
+// handleExpirationFlag applies the --expiration-time flag (RFC 3339) to rel,
+// setting its OptionalExpiresAt timestamp.
+func handleExpirationFlag(cmd *cobra.Command, rel *v1.Relationship) error {
+	expirationTime := cobrautil.MustGetString(cmd, "expiration-time")
+
+	if expirationTime != "" {
+		t, err := time.Parse(time.RFC3339, expirationTime)
+		if err != nil {
+			return fmt.Errorf("could not parse RFC 3339 timestamp: %w", err)
+		}
+		rel.OptionalExpiresAt = timestamppb.New(t)
+	}
+
+	return nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/commands/relationship_nowasm.go b/vendor/github.com/authzed/zed/internal/commands/relationship_nowasm.go
new file mode 100644
index 0000000..ea94114
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/commands/relationship_nowasm.go
@@ -0,0 +1,12 @@
+//go:build !wasm
+// +build !wasm
+
+package commands
+
+import (
+ "os"
+
+ "golang.org/x/term"
+)
+
+// isFileTerminal reports whether f is an interactive terminal (non-wasm builds).
+var isFileTerminal = func(f *os.File) bool { return term.IsTerminal(int(f.Fd())) }
diff --git a/vendor/github.com/authzed/zed/internal/commands/relationship_wasm.go b/vendor/github.com/authzed/zed/internal/commands/relationship_wasm.go
new file mode 100644
index 0000000..4388932
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/commands/relationship_wasm.go
@@ -0,0 +1,5 @@
+package commands
+
+import "os"
+
+// isFileTerminal always reports true under wasm, where no tty detection exists.
+var isFileTerminal = func(f *os.File) bool { return true }
diff --git a/vendor/github.com/authzed/zed/internal/commands/schema.go b/vendor/github.com/authzed/zed/internal/commands/schema.go
new file mode 100644
index 0000000..b56f152
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/commands/schema.go
@@ -0,0 +1,87 @@
+package commands
+
+import (
+ "context"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/jzelinskie/stringz"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/console"
+)
+
+// RegisterSchemaCmd attaches the `schema` command and its `read` subcommand
+// (with the --json flag) to rootCmd, returning the parent schema command.
+func RegisterSchemaCmd(rootCmd *cobra.Command) *cobra.Command {
+	rootCmd.AddCommand(schemaCmd)
+
+	schemaCmd.AddCommand(schemaReadCmd)
+	schemaReadCmd.Flags().Bool("json", false, "output as JSON")
+
+	return schemaCmd
+}
+
+var (
+	// schemaCmd is the parent command for schema subcommands.
+	schemaCmd = &cobra.Command{
+		Use:   "schema <subcommand>",
+		Short: "Manage schema for a permissions system",
+	}
+
+	// schemaReadCmd prints the current schema of the permissions system.
+	schemaReadCmd = &cobra.Command{
+		Use:               "read",
+		Short:             "Read the schema of a permissions system",
+		Args:              ValidationWrapper(cobra.ExactArgs(0)),
+		ValidArgsFunction: cobra.NoFileCompletions,
+		RunE:              schemaReadCmdFunc,
+	}
+)
+
+// schemaReadCmdFunc reads the schema from the server and prints it either as
+// pretty JSON (--json) or as plain schema text.
+func schemaReadCmdFunc(cmd *cobra.Command, _ []string) error {
+	client, err := client.NewClient(cmd)
+	if err != nil {
+		return err
+	}
+	request := &v1.ReadSchemaRequest{}
+	log.Trace().Interface("request", request).Msg("requesting schema read")
+
+	resp, err := client.ReadSchema(cmd.Context(), request)
+	if err != nil {
+		return err
+	}
+
+	if cobrautil.MustGetBool(cmd, "json") {
+		prettyProto, err := PrettyProto(resp)
+		if err != nil {
+			return err
+		}
+
+		console.Println(string(prettyProto))
+		return nil
+	}
+
+	console.Println(stringz.Join("\n\n", resp.SchemaText))
+	return nil
+}
+
+// ReadSchema calls read schema for the client and returns the schema found.
+// A NotFound status is treated as "no schema defined" and yields ("", nil)
+// rather than an error.
+func ReadSchema(ctx context.Context, client client.Client) (string, error) {
+	request := &v1.ReadSchemaRequest{}
+	log.Trace().Interface("request", request).Msg("requesting schema read")
+
+	resp, err := client.ReadSchema(ctx, request)
+	if err != nil {
+		errStatus, ok := status.FromError(err)
+		if !ok || errStatus.Code() != codes.NotFound {
+			return "", err
+		}
+
+		log.Debug().Msg("no schema defined")
+		return "", nil
+	}
+
+	return resp.SchemaText, nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/commands/util.go b/vendor/github.com/authzed/zed/internal/commands/util.go
new file mode 100644
index 0000000..6d5b4da
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/commands/util.go
@@ -0,0 +1,123 @@
+package commands
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "strings"
+
+ "github.com/TylerBrock/colorjson"
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/jzelinskie/stringz"
+ "github.com/spf13/cobra"
+ "google.golang.org/protobuf/encoding/protojson"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/types/known/structpb"
+
+ "github.com/authzed/authzed-go/pkg/requestmeta"
+)
+
+// ParseSubject parses the given subject string into its namespace, object ID
+// and relation, if valid. The expected form is `namespace:id[#relation]`;
+// a missing `#relation` part is not an error and yields an empty relation.
+func ParseSubject(s string) (namespace, id, relation string, err error) {
+	err = stringz.SplitExact(s, ":", &namespace, &id)
+	if err != nil {
+		return
+	}
+	// The relation segment is optional; swallow the split error when absent.
+	err = stringz.SplitExact(id, "#", &id, &relation)
+	if err != nil {
+		relation = ""
+		err = nil
+	}
+	return
+}
+
+// ParseType parses a type reference of the form `namespace#relation`,
+// returning an empty relation when no `#` separator is present.
+func ParseType(s string) (namespace, relation string) {
+	namespace, relation, _ = strings.Cut(s, "#")
+	return
+}
+
+// GetCaveatContext returns the caveat context entered via the --caveat-context
+// flag, if any; nil is returned when the flag is unset.
+func GetCaveatContext(cmd *cobra.Command) (*structpb.Struct, error) {
+	contextString := cobrautil.MustGetString(cmd, "caveat-context")
+	if len(contextString) == 0 {
+		return nil, nil
+	}
+
+	return ParseCaveatContext(contextString)
+}
+
+// ParseCaveatContext parses the given context JSON string into caveat context,
+// if valid. The JSON must be an object; it is converted to a structpb.Struct.
+func ParseCaveatContext(contextString string) (*structpb.Struct, error) {
+	contextMap := map[string]any{}
+	err := json.Unmarshal([]byte(contextString), &contextMap)
+	if err != nil {
+		return nil, fmt.Errorf("invalid caveat context JSON: %w", err)
+	}
+
+	context, err := structpb.NewStruct(contextMap)
+	if err != nil {
+		return nil, fmt.Errorf("could not construct caveat context: %w", err)
+	}
+	return context, err
+}
+
+// PrettyProto returns the given protocol buffer formatted into pretty text:
+// it marshals via protojson, round-trips through encoding/json, and colorizes
+// with two-space indentation. The two panics guard steps that can only fail
+// on programmer error (protojson output is always valid JSON).
+func PrettyProto(m proto.Message) ([]byte, error) {
+	encoded, err := protojson.Marshal(m)
+	if err != nil {
+		return nil, err
+	}
+	var obj interface{}
+	err = json.Unmarshal(encoded, &obj)
+	if err != nil {
+		panic("protojson decode failed: " + err.Error())
+	}
+
+	f := colorjson.NewFormatter()
+	f.Indent = 2
+	pretty, err := f.Marshal(obj)
+	if err != nil {
+		panic("colorjson encode failed: " + err.Error())
+	}
+
+	return pretty, nil
+}
+
+// InjectRequestID adds the value of the --request-id flag to the
+// context of the given command. Intended for use as a cobra PreRunE hook;
+// it is a no-op when the flag is empty or the command has no context.
+func InjectRequestID(cmd *cobra.Command, _ []string) error {
+	ctx := cmd.Context()
+	requestID := cobrautil.MustGetString(cmd, "request-id")
+	if ctx != nil && requestID != "" {
+		cmd.SetContext(requestmeta.WithRequestID(ctx, requestID))
+	}
+
+	return nil
+}
+
+// ValidationError is used to wrap errors that are cobra validation errors. It should be used to
+// wrap the Command.PositionalArgs function in order to be able to determine if the error is a validation error.
+// This is used to determine if an error should print the usage string. Unfortunately Cobra parameter parsing
+// and parameter validation are handled differently, and the latter does not trigger calling Command.FlagErrorFunc
+type ValidationError struct {
+	error
+}
+
+// Is reports whether err is (or wraps) a ValidationError, enabling detection
+// via errors.Is regardless of the receiver's wrapped error value.
+func (ve ValidationError) Is(err error) bool {
+	var validationError ValidationError
+	return errors.As(err, &validationError)
+}
+
+// ValidationWrapper wraps a cobra PositionalArgs validator so that any error it
+// returns is tagged as a ValidationError, letting callers distinguish
+// validation failures from other errors.
+func ValidationWrapper(f cobra.PositionalArgs) cobra.PositionalArgs {
+	return func(cmd *cobra.Command, args []string) error {
+		if err := f(cmd, args); err != nil {
+			return ValidationError{error: err}
+		}
+
+		return nil
+	}
+}
diff --git a/vendor/github.com/authzed/zed/internal/commands/watch.go b/vendor/github.com/authzed/zed/internal/commands/watch.go
new file mode 100644
index 0000000..96b3451
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/commands/watch.go
@@ -0,0 +1,212 @@
+package commands
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/signal"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/spf13/cobra"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/console"
+)
+
+// Flag storage shared by the watch commands.
+var (
+	watchObjectTypes         []string // --object_types: restrict updates to these object types
+	watchRevision            string   // --revision: ZedToken cursor to start watching from
+	watchTimestamps          bool     // --timestamp: prefix each event with the local receive time
+	watchRelationshipFilters []string // --filter: relationship filter expressions (watch relationships only)
+)
+
+// RegisterWatchCmd attaches the deprecated top-level `watch` command and its
+// flags to rootCmd, returning the watch command.
+func RegisterWatchCmd(rootCmd *cobra.Command) *cobra.Command {
+	rootCmd.AddCommand(watchCmd)
+
+	watchCmd.Flags().StringSliceVar(&watchObjectTypes, "object_types", nil, "optional object types to watch updates for")
+	watchCmd.Flags().StringVar(&watchRevision, "revision", "", "optional revision at which to start watching")
+	watchCmd.Flags().BoolVar(&watchTimestamps, "timestamp", false, "shows timestamp of incoming update events")
+	return watchCmd
+}
+
+// RegisterWatchRelationshipCmd attaches the `watch relationships` command,
+// which additionally supports --filter expressions, to parentCmd.
+func RegisterWatchRelationshipCmd(parentCmd *cobra.Command) *cobra.Command {
+	parentCmd.AddCommand(watchRelationshipsCmd)
+	watchRelationshipsCmd.Flags().StringSliceVar(&watchObjectTypes, "object_types", nil, "optional object types to watch updates for")
+	watchRelationshipsCmd.Flags().StringVar(&watchRevision, "revision", "", "optional revision at which to start watching")
+	watchRelationshipsCmd.Flags().BoolVar(&watchTimestamps, "timestamp", false, "shows timestamp of incoming update events")
+	watchRelationshipsCmd.Flags().StringSliceVar(&watchRelationshipFilters, "filter", nil, "optional filter(s) for the watch stream. Example: `optional_resource_type:optional_resource_id_or_prefix#optional_relation@optional_subject_filter`")
+	return watchRelationshipsCmd
+}
+
+// watchCmd is the deprecated top-level watch command; both it and
+// watchRelationshipsCmd share the same handler.
+var watchCmd = &cobra.Command{
+	Use:        "watch [object_types, ...] [start_cursor]",
+	Short:      "Watches the stream of relationship updates from the server",
+	Args:       ValidationWrapper(cobra.RangeArgs(0, 2)),
+	RunE:       watchCmdFunc,
+	Deprecated: "deprecated; please use `zed watch relationships` instead",
+}
+
+// watchRelationshipsCmd is the supported `watch relationships` form.
+var watchRelationshipsCmd = &cobra.Command{
+	Use:   "watch [object_types, ...] [start_cursor]",
+	Short: "Watches the stream of relationship updates from the server",
+	Args:  ValidationWrapper(cobra.RangeArgs(0, 2)),
+	RunE:  watchCmdFunc,
+}
+
+// watchCmdFunc opens a Watch stream with the configured object types, filters
+// and optional start cursor, then prints each relationship update
+// (CREATED/TOUCHED/DELETED) until the stream ends, the context is canceled,
+// or an interrupt signal is received.
+func watchCmdFunc(cmd *cobra.Command, _ []string) error {
+	console.Printf("starting watch stream over types %v and revision %v\n", watchObjectTypes, watchRevision)
+
+	cli, err := client.NewClient(cmd)
+	if err != nil {
+		return err
+	}
+
+	relFilters := make([]*v1.RelationshipFilter, 0, len(watchRelationshipFilters))
+	for _, filter := range watchRelationshipFilters {
+		relFilter, err := parseRelationshipFilter(filter)
+		if err != nil {
+			return err
+		}
+		relFilters = append(relFilters, relFilter)
+	}
+
+	req := &v1.WatchRequest{
+		OptionalObjectTypes:         watchObjectTypes,
+		OptionalRelationshipFilters: relFilters,
+	}
+	if watchRevision != "" {
+		req.OptionalStartCursor = &v1.ZedToken{Token: watchRevision}
+	}
+
+	ctx, cancel := context.WithCancel(cmd.Context())
+	defer cancel()
+
+	signalctx, interruptCancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
+	defer interruptCancel()
+
+	watchStream, err := cli.Watch(ctx, req)
+	if err != nil {
+		return err
+	}
+
+	for {
+		select {
+		case <-signalctx.Done():
+			console.Errorf("stream interrupted after program termination\n")
+			return nil
+		case <-ctx.Done():
+			console.Errorf("stream canceled after context cancellation\n")
+			return nil
+		default:
+			// NOTE(review): Recv blocks, so cancellation/signals are only
+			// observed between messages — confirm this is acceptable.
+			resp, err := watchStream.Recv()
+			if err != nil {
+				return err
+			}
+
+			for _, update := range resp.Updates {
+				if watchTimestamps {
+					console.Printf("%v: ", time.Now())
+				}
+
+				switch update.Operation {
+				case v1.RelationshipUpdate_OPERATION_CREATE:
+					console.Printf("CREATED ")
+
+				case v1.RelationshipUpdate_OPERATION_DELETE:
+					console.Printf("DELETED ")
+
+				case v1.RelationshipUpdate_OPERATION_TOUCH:
+					console.Printf("TOUCHED ")
+				}
+
+				subjectRelation := ""
+				if update.Relationship.Subject.OptionalRelation != "" {
+					subjectRelation = " " + update.Relationship.Subject.OptionalRelation
+				}
+
+				console.Printf("%s:%s %s %s:%s%s\n",
+					update.Relationship.Resource.ObjectType,
+					update.Relationship.Resource.ObjectId,
+					update.Relationship.Relation,
+					update.Relationship.Subject.Object.ObjectType,
+					update.Relationship.Subject.Object.ObjectId,
+					subjectRelation,
+				)
+			}
+		}
+	}
+}
+
+// parseRelationshipFilter parses a --filter expression of the form
+// `resource_type[:id_or_prefix%][#relation][@subject_filter]` into a
+// v1.RelationshipFilter. A trailing '%' on the resource ID selects by prefix.
+func parseRelationshipFilter(relFilterStr string) (*v1.RelationshipFilter, error) {
+	relFilter := &v1.RelationshipFilter{}
+	pieces := strings.Split(relFilterStr, "@")
+	if len(pieces) > 2 {
+		return nil, fmt.Errorf("invalid relationship filter: %s", relFilterStr)
+	}
+
+	if len(pieces) == 2 {
+		subjectFilter, err := parseSubjectFilter(pieces[1])
+		if err != nil {
+			return nil, err
+		}
+		relFilter.OptionalSubjectFilter = subjectFilter
+	}
+
+	if len(pieces) > 0 {
+		resourcePieces := strings.Split(pieces[0], "#")
+		if len(resourcePieces) > 2 {
+			return nil, fmt.Errorf("invalid relationship filter: %s", relFilterStr)
+		}
+
+		if len(resourcePieces) == 2 {
+			relFilter.OptionalRelation = resourcePieces[1]
+		}
+
+		resourceTypePieces := strings.Split(resourcePieces[0], ":")
+		if len(resourceTypePieces) > 2 {
+			return nil, fmt.Errorf("invalid relationship filter: %s", relFilterStr)
+		}
+
+		relFilter.ResourceType = resourceTypePieces[0]
+		if len(resourceTypePieces) == 2 {
+			optionalResourceIDOrPrefix := resourceTypePieces[1]
+			if strings.HasSuffix(optionalResourceIDOrPrefix, "%") {
+				relFilter.OptionalResourceIdPrefix = strings.TrimSuffix(optionalResourceIDOrPrefix, "%")
+			} else {
+				relFilter.OptionalResourceId = optionalResourceIDOrPrefix
+			}
+		}
+	}
+
+	return relFilter, nil
+}
+
+// parseSubjectFilter parses the subject portion of a --filter expression, of
+// the form `subject_type[:subject_id][#relation]`, into a v1.SubjectFilter.
+func parseSubjectFilter(subjectFilterStr string) (*v1.SubjectFilter, error) {
+	subjectFilter := &v1.SubjectFilter{}
+	pieces := strings.Split(subjectFilterStr, "#")
+	if len(pieces) > 2 {
+		return nil, fmt.Errorf("invalid subject filter: %s", subjectFilterStr)
+	}
+
+	subjectTypePieces := strings.Split(pieces[0], ":")
+	if len(subjectTypePieces) > 2 {
+		return nil, fmt.Errorf("invalid subject filter: %s", subjectFilterStr)
+	}
+
+	subjectFilter.SubjectType = subjectTypePieces[0]
+	if len(subjectTypePieces) == 2 {
+		subjectFilter.OptionalSubjectId = subjectTypePieces[1]
+	}
+
+	if len(pieces) == 2 {
+		subjectFilter.OptionalRelation = &v1.SubjectFilter_RelationFilter{
+			Relation: pieces[1],
+		}
+	}
+
+	return subjectFilter, nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/console/console.go b/vendor/github.com/authzed/zed/internal/console/console.go
new file mode 100644
index 0000000..53f815c
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/console/console.go
@@ -0,0 +1,60 @@
+package console
+
+import (
+ "fmt"
+ "os"
+ "time"
+
+ "github.com/mattn/go-isatty"
+ "github.com/schollz/progressbar/v3"
+)
+
+// Printf defines an (overridable) function for printing to the console via stdout.
+var Printf = func(format string, a ...any) {
+	fmt.Printf(format, a...)
+}
+
+// Print defines an (overridable) function for printing values to stdout
+// without a format string.
+var Print = func(a ...any) {
+	fmt.Print(a...)
+}
+
+// Errorf defines an (overridable) function for printing to the console via stderr.
+var Errorf = func(format string, a ...any) {
+	_, err := fmt.Fprintf(os.Stderr, format, a...)
+	if err != nil {
+		panic(err)
+	}
+}
+
+// Println prints a line with optional values to the console.
+// Each value is printed on its own line via Printf.
+var Println = func(values ...any) {
+	for _, value := range values {
+		Printf("%v\n", value)
+	}
+}
+
+// CreateProgressBar creates a new progress bar with the given description and defaults adjusted to zed's UX experience.
+// When stderr is not a terminal the bar is created invisible so piped/logged
+// output stays clean; on a terminal a spinner-style bar with throttled
+// rendering and relationship counts is used instead.
+func CreateProgressBar(description string) *progressbar.ProgressBar {
+	bar := progressbar.NewOptions(-1,
+		progressbar.OptionSetWidth(10),
+		progressbar.OptionSetRenderBlankState(true),
+		progressbar.OptionSetVisibility(false),
+	)
+	if isatty.IsTerminal(os.Stderr.Fd()) {
+		bar = progressbar.NewOptions64(-1,
+			progressbar.OptionSetDescription(description),
+			progressbar.OptionSetWriter(os.Stderr),
+			progressbar.OptionSetWidth(10),
+			progressbar.OptionThrottle(65*time.Millisecond),
+			progressbar.OptionShowCount(),
+			progressbar.OptionShowIts(),
+			progressbar.OptionSetItsString("relationship"),
+			progressbar.OptionOnCompletion(func() { _, _ = fmt.Fprint(os.Stderr, "\n") }),
+			progressbar.OptionSpinnerType(14),
+			progressbar.OptionFullWidth(),
+			progressbar.OptionSetRenderBlankState(true),
+		)
+	}
+
+	return bar
+}
diff --git a/vendor/github.com/authzed/zed/internal/decode/decoder.go b/vendor/github.com/authzed/zed/internal/decode/decoder.go
new file mode 100644
index 0000000..f8a02ad
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/decode/decoder.go
@@ -0,0 +1,215 @@
+package decode
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "path"
+ "path/filepath"
+ "regexp"
+ "strings"
+
+ "github.com/rs/zerolog/log"
+ "gopkg.in/yaml.v3"
+
+ composable "github.com/authzed/spicedb/pkg/composableschemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/composableschemadsl/generator"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/validationfile"
+ "github.com/authzed/spicedb/pkg/validationfile/blocks"
+)
+
// playgroundPattern matches playground share URLs whose path ends in one of
// the known tab segments, e.g. ".../s/<id>/schema". The alternation is
// parenthesized so it is scoped to the final path segment; without the group,
// any URL merely containing "relationships", "assertions", or "expected"
// anywhere would match.
var playgroundPattern = regexp.MustCompile("^.*/s/.*/(schema|relationships|assertions|expected).*$")
+
// SchemaRelationships holds the schema (as a string) and a list of
// relationships (as a string) in the format from the devtools download API.
type SchemaRelationships struct {
	Schema        string `yaml:"schema"`
	Relationships string `yaml:"relationships"`
}

// Func will decode into the supplied object. It returns the raw bytes that
// were read and a boolean reporting whether the input was a schema-only
// document (as opposed to a full validation file).
type Func func(out interface{}) ([]byte, bool, error)
+
+// DecoderForURL returns the appropriate decoder for a given URL.
+// Some URLs have special handling to dereference to the actual file.
+func DecoderForURL(u *url.URL) (d Func, err error) {
+ switch s := u.Scheme; s {
+ case "", "file":
+ d = fileDecoder(u)
+ case "http", "https":
+ d = httpDecoder(u)
+ default:
+ err = fmt.Errorf("%s scheme not supported", s)
+ }
+ return
+}
+
+func fileDecoder(u *url.URL) Func {
+ return func(out interface{}) ([]byte, bool, error) {
+ file, err := os.Open(u.Path)
+ if err != nil {
+ return nil, false, err
+ }
+ data, err := io.ReadAll(file)
+ if err != nil {
+ return nil, false, err
+ }
+ isOnlySchema, err := unmarshalAsYAMLOrSchemaWithFile(data, out, u.Path)
+ return data, isOnlySchema, err
+ }
+}
+
// httpDecoder returns a decoder that fetches the document over HTTP(S),
// after rewriting known URL shapes (playground, gist, pastebin) in place to
// point at their raw/downloadable form.
func httpDecoder(u *url.URL) Func {
	rewriteURL(u)
	return directHTTPDecoder(u)
}
+
+func rewriteURL(u *url.URL) {
+ // match playground urls
+ if playgroundPattern.MatchString(u.Path) {
+ u.Path = u.Path[:strings.LastIndex(u.Path, "/")]
+ u.Path += "/download"
+ return
+ }
+
+ switch u.Hostname() {
+ case "gist.github.com":
+ u.Host = "gist.githubusercontent.com"
+ u.Path = path.Join(u.Path, "/raw")
+ case "pastebin.com":
+ if ok, _ := path.Match("/raw/*", u.Path); ok {
+ return
+ }
+ u.Path = path.Join("/raw/", u.Path)
+ }
+}
+
+func directHTTPDecoder(u *url.URL) Func {
+ return func(out interface{}) ([]byte, bool, error) {
+ log.Debug().Stringer("url", u).Send()
+ r, err := http.Get(u.String())
+ if err != nil {
+ return nil, false, err
+ }
+ defer r.Body.Close()
+ data, err := io.ReadAll(r.Body)
+ if err != nil {
+ return nil, false, err
+ }
+
+ isOnlySchema, err := unmarshalAsYAMLOrSchema("", data, out)
+ return data, isOnlySchema, err
+ }
+}
+
+// Uses the files passed in the args and looks for the specified schemaFile to parse the YAML.
+func unmarshalAsYAMLOrSchemaWithFile(data []byte, out interface{}, filename string) (bool, error) {
+ if strings.Contains(string(data), "schemaFile:") && !strings.Contains(string(data), "schema:") {
+ if err := yaml.Unmarshal(data, out); err != nil {
+ return false, err
+ }
+ validationFile, ok := out.(*validationfile.ValidationFile)
+ if !ok {
+ return false, fmt.Errorf("could not cast unmarshalled file to validationfile")
+ }
+
+ // Need to join the original filepath with the requested filepath
+ // to construct the path to the referenced schema file.
+ // NOTE: This does not allow for yaml files to transitively reference
+ // each other's schemaFile fields.
+ // TODO: enable this behavior
+ schemaPath := filepath.Join(path.Dir(filename), validationFile.SchemaFile)
+
+ if !filepath.IsLocal(schemaPath) {
+ // We want to prevent access of files that are outside of the folder
+ // where the command was originally invoked. This should do that.
+ return false, fmt.Errorf("schema filepath %s must be local to where the command was invoked", schemaPath)
+ }
+
+ file, err := os.Open(schemaPath)
+ if err != nil {
+ return false, err
+ }
+ data, err = io.ReadAll(file)
+ if err != nil {
+ return false, err
+ }
+ }
+ return unmarshalAsYAMLOrSchema(filename, data, out)
+}
+
// unmarshalAsYAMLOrSchema decodes data either as a bare schema (compiled into
// out via compileSchemaFromData) or as a validation file in YAML form,
// returning true when the input was schema-only.
//
// Detection is heuristic: the presence of the literal substrings "schema:",
// "relationships:", and "schemaFile:" in the raw text decides which parser is
// attempted.
func unmarshalAsYAMLOrSchema(filename string, data []byte, out interface{}) (bool, error) {
	inputString := string(data)

	// Check for indications of a schema-only file.
	if !strings.Contains(inputString, "schema:") && !strings.Contains(inputString, "relationships:") {
		if err := compileSchemaFromData(filename, inputString, out); err != nil {
			return false, err
		}
		return true, nil
	}

	// Reached when "relationships:" is present but neither an inline schema
	// nor a schemaFile reference is.
	if !strings.Contains(inputString, "schema:") && !strings.Contains(inputString, "schemaFile:") {
		// If there is no schema and no schemaFile and it doesn't compile then it must be yaml with missing fields
		if err := compileSchemaFromData(filename, inputString, out); err != nil {
			return false, errors.New("either schema or schemaFile must be present")
		}
		return true, nil
	}
	// Try to unparse as YAML for the validation file format.
	if err := yaml.Unmarshal(data, out); err != nil {
		return false, err
	}

	return false, nil
}
+
// compileSchemaFromData attempts to compile using the old DSL and the new composable DSL,
// but prefers the new DSL.
// It returns the errors returned by both compilations.
//
// out must be a *validationfile.ValidationFile; on success its Schema block is
// populated with the (possibly regenerated) schema text.
//
// NOTE(review): errors.Join is non-nil when EITHER compilation failed, so a
// schema valid under only one DSL still yields a non-nil error here — confirm
// callers intend to treat that as a failure.
func compileSchemaFromData(filename, schemaString string, out interface{}) error {
	var (
		standardCompileErr   error
		composableCompiled   *composable.CompiledSchema
		composableCompileErr error
		vfile                validationfile.ValidationFile
	)

	// Work on a copy of the caller's validation file; it is written back at
	// the end so partial state is never observed on error paths.
	vfile = *out.(*validationfile.ValidationFile)
	vfile.Schema = blocks.ParsedSchema{
		SourcePosition: spiceerrors.SourcePosition{LineNumber: 1, ColumnPosition: 1},
	}

	// First, the standard (legacy) DSL.
	_, standardCompileErr = compiler.Compile(compiler.InputSchema{
		Source:       input.Source("schema"),
		SchemaString: schemaString,
	}, compiler.AllowUnprefixedObjectType())

	if standardCompileErr == nil {
		vfile.Schema.Schema = schemaString
	}

	// Then the composable DSL; running second, its generated schema overwrites
	// the standard result on success, which is what makes it "preferred".
	inputSourceFolder := filepath.Dir(filename)
	composableCompiled, composableCompileErr = composable.Compile(composable.InputSchema{
		SchemaString: schemaString,
	}, composable.AllowUnprefixedObjectType(), composable.SourceFolder(inputSourceFolder))

	if composableCompileErr == nil {
		compiledSchemaString, _, err := generator.GenerateSchema(composableCompiled.OrderedDefinitions)
		if err != nil {
			return fmt.Errorf("could not generate string schema: %w", err)
		}
		vfile.Schema.Schema = compiledSchemaString
	}

	err := errors.Join(standardCompileErr, composableCompileErr)

	*out.(*validationfile.ValidationFile) = vfile
	return err
}
diff --git a/vendor/github.com/authzed/zed/internal/grpcutil/batch.go b/vendor/github.com/authzed/zed/internal/grpcutil/batch.go
new file mode 100644
index 0000000..640085c
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/grpcutil/batch.go
@@ -0,0 +1,68 @@
+package grpcutil
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "runtime"
+
+ "golang.org/x/sync/errgroup"
+ "golang.org/x/sync/semaphore"
+)
+
// minimum returns the smaller of the two provided integers.
func minimum(a int, b int) int {
	if b < a {
		return b
	}
	return a
}
+
// EachFunc is a callback function that is called for each batch. no is the
// (zero-based) batch number, start is the starting index of this batch in the
// slice, and end is the ending index (exclusive) of this batch in the slice.
type EachFunc func(ctx context.Context, no int, start int, end int) error
+
+// ConcurrentBatch will calculate the minimum number of batches to required to batch n items
+// with batchSize batches. For each batch, it will execute the each function.
+// These functions will be processed in parallel using maxWorkers number of
+// goroutines. If maxWorkers is 1, then batching will happen sychronously. If
+// maxWorkers is 0, then GOMAXPROCS number of workers will be used.
+//
+// If an error occurs during a batch, all the worker's contexts are cancelled
+// and the original error is returned.
+func ConcurrentBatch(ctx context.Context, n int, batchSize int, maxWorkers int, each EachFunc) error {
+ if n < 0 {
+ return errors.New("cannot batch items of length < 0")
+ } else if n == 0 {
+ // Batching zero items is a noop.
+ return nil
+ }
+
+ if batchSize < 1 {
+ return errors.New("cannot batch items with batch size < 1")
+ }
+
+ if maxWorkers < 0 {
+ return errors.New("cannot batch items with workers < 0")
+ } else if maxWorkers == 0 {
+ maxWorkers = runtime.GOMAXPROCS(0)
+ }
+
+ sem := semaphore.NewWeighted(int64(maxWorkers))
+ g, ctx := errgroup.WithContext(ctx)
+ numBatches := (n + batchSize - 1) / batchSize
+ for i := 0; i < numBatches; i++ {
+ if err := sem.Acquire(ctx, 1); err != nil {
+ return fmt.Errorf("failed to acquire semaphore for batch number %d: %w", i, err)
+ }
+
+ batchNum := i
+ g.Go(func() error {
+ defer sem.Release(1)
+ start := batchNum * batchSize
+ end := minimum(start+batchSize, n)
+ return each(ctx, batchNum, start, end)
+ })
+ }
+ return g.Wait()
+}
diff --git a/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go b/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go
new file mode 100644
index 0000000..c6537b9
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go
@@ -0,0 +1,164 @@
+package grpcutil
+
+import (
+ "context"
+ "errors"
+ "io"
+ "sync"
+ "time"
+
+ "github.com/rs/zerolog/log"
+ "golang.org/x/mod/semver"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+
+ "github.com/authzed/authzed-go/pkg/requestmeta"
+ "github.com/authzed/authzed-go/pkg/responsemeta"
+ "github.com/authzed/spicedb/pkg/releases"
+)
+
// Compile-time assertion that LogDispatchTrailers and CheckServerVersion implement the
// grpc.UnaryClientInterceptor interface.
var (
	_ grpc.UnaryClientInterceptor = grpc.UnaryClientInterceptor(LogDispatchTrailers)
	_ grpc.UnaryClientInterceptor = grpc.UnaryClientInterceptor(CheckServerVersion)
)

// once guards the server-version comparison in CheckServerVersion so it runs
// at most one time per process, regardless of how many RPCs pass through.
var once sync.Once
+
// CheckServerVersion implements a gRPC unary interceptor that requests the server version
// from SpiceDB and, if found, compares it to the current released version.
//
// Every invocation forwards the RPC and requests the version response header;
// the comparison and logging happen at most once per process (guarded by the
// package-level once).
func CheckServerVersion(
	ctx context.Context,
	method string,
	req, reply interface{},
	cc *grpc.ClientConn,
	invoker grpc.UnaryInvoker,
	callOpts ...grpc.CallOption,
) error {
	// Ask the server to include its version in the response headers.
	var headerMD metadata.MD
	ctx = requestmeta.AddRequestHeaders(ctx, requestmeta.RequestServerVersion)
	err := invoker(ctx, method, req, reply, cc, append(callOpts, grpc.Header(&headerMD))...)
	if err != nil {
		return err
	}

	once.Do(func() {
		version := headerMD.Get(string(responsemeta.ServerVersion))
		if len(version) == 0 {
			log.Debug().Msg("error reading server version response header; it may be disabled on the server")
		} else if len(version) == 1 {
			currentVersion := version[0]

			// If there is a build on the version, then do not compare.
			if semver.Build(currentVersion) != "" {
				log.Debug().Str("this-version", currentVersion).Msg("received build version of SpiceDB")
				return
			}

			// Bound the release lookup so a slow network cannot stall this path.
			rctx, cancel := context.WithTimeout(ctx, time.Second*2)
			defer cancel()

			state, _, release, cerr := releases.CheckIsLatestVersion(rctx, func() (string, error) {
				return currentVersion, nil
			}, releases.GetLatestRelease)
			if cerr != nil {
				log.Debug().Err(cerr).Msg("error looking up currently released version")
			} else {
				// Log at a severity matching how actionable each outcome is.
				switch state {
				case releases.UnreleasedVersion:
					log.Warn().Str("version", currentVersion).Msg("not calling a released version of SpiceDB")
					return

				case releases.UpdateAvailable:
					log.Warn().Str("this-version", currentVersion).Str("latest-released-version", release.Version).Msgf("the version of SpiceDB being called is out of date. See: %s", release.ViewURL)
					return

				case releases.UpToDate:
					log.Debug().Str("latest-released-version", release.Version).Msg("the version of SpiceDB being called is the latest released version")
					return

				case releases.Unknown:
					log.Warn().Str("unknown-released-version", release.Version).Msg("unable to check for a new SpiceDB version")
					return

				default:
					panic("Unknown state for CheckAndLogRunE")
				}
			}
		}
	})

	return nil
}
+
+// LogDispatchTrailers implements a gRPC unary interceptor that logs the
+// dispatch metadata that is present in response trailers from SpiceDB.
+func LogDispatchTrailers(
+ ctx context.Context,
+ method string,
+ req, reply interface{},
+ cc *grpc.ClientConn,
+ invoker grpc.UnaryInvoker,
+ callOpts ...grpc.CallOption,
+) error {
+ var trailerMD metadata.MD
+ err := invoker(ctx, method, req, reply, cc, append(callOpts, grpc.Trailer(&trailerMD))...)
+ outputDispatchTrailers(trailerMD)
+ return err
+}
+
+func outputDispatchTrailers(trailerMD metadata.MD) {
+ log.Trace().Interface("trailers", trailerMD).Msg("parsed trailers")
+
+ dispatchCount, trailerErr := responsemeta.GetIntResponseTrailerMetadata(
+ trailerMD,
+ responsemeta.DispatchedOperationsCount,
+ )
+ if trailerErr != nil {
+ log.Debug().Err(trailerErr).Msg("error reading dispatched operations trailer")
+ }
+
+ cachedCount, trailerErr := responsemeta.GetIntResponseTrailerMetadata(
+ trailerMD,
+ responsemeta.CachedOperationsCount,
+ )
+ if trailerErr != nil {
+ log.Debug().Err(trailerErr).Msg("error reading cached operations trailer")
+ }
+
+ log.Debug().
+ Int("dispatch", dispatchCount).
+ Int("cached", cachedCount).
+ Msg("extracted response dispatch metadata")
+}
+
// StreamLogDispatchTrailers implements a gRPC stream interceptor that logs the
// dispatch metadata that is present in response trailers from SpiceDB.
//
// The returned stream is wrapped so that the trailers are logged when the
// stream is fully drained (see wrappedStream.RecvMsg).
func StreamLogDispatchTrailers(
	ctx context.Context,
	desc *grpc.StreamDesc,
	cc *grpc.ClientConn,
	method string,
	streamer grpc.Streamer,
	callOpts ...grpc.CallOption,
) (grpc.ClientStream, error) {
	stream, err := streamer(ctx, desc, cc, method, callOpts...)
	if err != nil {
		return nil, err
	}

	return &wrappedStream{stream}, nil
}
+
+type wrappedStream struct {
+ grpc.ClientStream
+}
+
+func (w *wrappedStream) RecvMsg(m interface{}) error {
+ err := w.ClientStream.RecvMsg(m)
+ if err != nil && errors.Is(err, io.EOF) {
+ outputDispatchTrailers(w.Trailer())
+ }
+ return err
+}
diff --git a/vendor/github.com/authzed/zed/internal/printers/debug.go b/vendor/github.com/authzed/zed/internal/printers/debug.go
new file mode 100644
index 0000000..2dbb9a6
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/printers/debug.go
@@ -0,0 +1,197 @@
+package printers
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/gookit/color"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// DisplayCheckTrace prints out the check trace found in the given debug message.
+func DisplayCheckTrace(checkTrace *v1.CheckDebugTrace, tp *TreePrinter, hasError bool) {
+ displayCheckTrace(checkTrace, tp, hasError, map[string]struct{}{})
+}
+
// displayCheckTrace renders a single node of the check trace into tp and then
// recurses into its sub-problems. hasError enables cycle detection/marking,
// and encountered tracks resource#permission keys already rendered on this
// path so that a revisited node is labeled "(cycle)" and recursion stops.
func displayCheckTrace(checkTrace *v1.CheckDebugTrace, tp *TreePrinter, hasError bool, encountered map[string]struct{}) {
	// Render helpers for the basic ANSI palette.
	red := color.FgRed.Render
	green := color.FgGreen.Render
	cyan := color.FgCyan.Render
	white := color.FgWhite.Render
	faint := color.FgGray.Render
	magenta := color.FgMagenta.Render
	yellow := color.FgYellow.Render

	// 256-color renderers for elements outside the basic palette.
	orange := color.C256(166).Sprint
	purple := color.C256(99).Sprint
	lightgreen := color.C256(35).Sprint
	caveatColor := color.C256(198).Sprint

	// Defaults assume "has permission"; adjusted below based on the result.
	hasPermission := green("✓")
	resourceColor := white
	permissionColor := color.FgWhite.Render

	// Tint permissions and relations differently for at-a-glance reading.
	switch checkTrace.PermissionType {
	case v1.CheckDebugTrace_PERMISSION_TYPE_PERMISSION:
		permissionColor = lightgreen
	case v1.CheckDebugTrace_PERMISSION_TYPE_RELATION:
		permissionColor = orange
	}

	// Choose the status glyph from the check result; conditional results
	// additionally depend on how the caveat evaluated.
	switch checkTrace.Result {
	case v1.CheckDebugTrace_PERMISSIONSHIP_CONDITIONAL_PERMISSION:
		switch checkTrace.CaveatEvaluationInfo.Result {
		case v1.CaveatEvalInfo_RESULT_FALSE:
			hasPermission = red("⨉")
			resourceColor = faint
			permissionColor = faint

		case v1.CaveatEvalInfo_RESULT_MISSING_SOME_CONTEXT:
			hasPermission = magenta("?")
			resourceColor = faint
			permissionColor = faint
		}
	case v1.CheckDebugTrace_PERMISSIONSHIP_NO_PERMISSION:
		hasPermission = red("⨉")
		resourceColor = faint
		permissionColor = faint
	case v1.CheckDebugTrace_PERMISSIONSHIP_UNSPECIFIED:
		hasPermission = yellow("∵")
	}

	// Annotate cached results with their cache source; Source appears to be
	// of the form "<kind>:..." when present.
	additional := ""
	if checkTrace.GetWasCachedResult() {
		sourceKind := ""
		source := checkTrace.Source
		if source != "" {
			parts := strings.Split(source, ":")
			if len(parts) > 0 {
				sourceKind = parts[0]
			}
		}
		switch sourceKind {
		case "":
			additional = cyan(" (cached)")

		case "spicedb":
			additional = cyan(" (cached by spicedb)")

		case "materialize":
			additional = purple(" (cached by materialize)")

		default:
			additional = cyan(fmt.Sprintf(" (cached by %s)", sourceKind))
		}
	} else if hasError && isPartOfCycle(checkTrace, map[string]struct{}{}) {
		// Uncached node that participates in a cycle: flag it prominently.
		hasPermission = orange("!")
		resourceColor = white
	}

	// When an error occurred, detect re-visits of the same node so rendering
	// terminates instead of recursing forever.
	isEndOfCycle := false
	if hasError {
		key := cycleKey(checkTrace)
		_, isEndOfCycle = encountered[key]
		if isEndOfCycle {
			additional = color.C256(166).Sprint(" (cycle)")
		}
		encountered[key] = struct{}{}
	}

	timing := ""
	if checkTrace.Duration != nil {
		timing = fmt.Sprintf(" (%s)", checkTrace.Duration.AsDuration().String())
	}

	tp = tp.Child(
		fmt.Sprintf(
			"%s %s:%s %s%s%s",
			hasPermission,
			resourceColor(checkTrace.Resource.ObjectType),
			resourceColor(checkTrace.Resource.ObjectId),
			permissionColor(checkTrace.Permission),
			additional,
			timing,
		),
	)

	// The repeated node's subtree was already rendered earlier on this path.
	if isEndOfCycle {
		return
	}

	// Render the caveat evaluation (expression, context, missing keys).
	if checkTrace.GetCaveatEvaluationInfo() != nil {
		indicator := ""
		exprColor := color.FgWhite.Render
		switch checkTrace.CaveatEvaluationInfo.Result {
		case v1.CaveatEvalInfo_RESULT_FALSE:
			indicator = red("⨉")
			exprColor = faint

		case v1.CaveatEvalInfo_RESULT_TRUE:
			indicator = green("✓")

		case v1.CaveatEvalInfo_RESULT_MISSING_SOME_CONTEXT:
			indicator = magenta("?")
		}

		white := color.HEXStyle("fff")
		white.SetOpts(color.Opts{color.OpItalic})

		contextMap := checkTrace.CaveatEvaluationInfo.Context.AsMap()
		caveatName := checkTrace.CaveatEvaluationInfo.CaveatName

		c := tp.Child(fmt.Sprintf("%s %s %s", indicator, exprColor(checkTrace.CaveatEvaluationInfo.Expression), caveatColor(caveatName)))
		if len(contextMap) > 0 {
			contextJSON, _ := json.MarshalIndent(contextMap, "", " ")
			c.Child(string(contextJSON))
		} else {
			if checkTrace.CaveatEvaluationInfo.Result != v1.CaveatEvalInfo_RESULT_MISSING_SOME_CONTEXT {
				c.Child(faint("(no matching context found)"))
			}
		}

		if checkTrace.CaveatEvaluationInfo.Result == v1.CaveatEvalInfo_RESULT_MISSING_SOME_CONTEXT {
			c.Child(fmt.Sprintf("missing context: %s", strings.Join(checkTrace.CaveatEvaluationInfo.PartialCaveatInfo.MissingRequiredContext, ", ")))
		}
	}

	// Recurse into sub-problems; a leaf that granted permission instead shows
	// the terminal subject that satisfied the check.
	if checkTrace.GetSubProblems() != nil {
		for _, subProblem := range checkTrace.GetSubProblems().Traces {
			displayCheckTrace(subProblem, tp, hasError, encountered)
		}
	} else if checkTrace.Result == v1.CheckDebugTrace_PERMISSIONSHIP_HAS_PERMISSION {
		tp.Child(purple(fmt.Sprintf("%s:%s %s", checkTrace.Subject.Object.ObjectType, checkTrace.Subject.Object.ObjectId, checkTrace.Subject.OptionalRelation)))
	}
}
+
+func cycleKey(checkTrace *v1.CheckDebugTrace) string {
+ return fmt.Sprintf("%s#%s", tuple.V1StringObjectRef(checkTrace.Resource), checkTrace.Permission)
+}
+
+func isPartOfCycle(checkTrace *v1.CheckDebugTrace, encountered map[string]struct{}) bool {
+ if checkTrace.GetSubProblems() == nil {
+ return false
+ }
+
+ encounteredCopy := make(map[string]struct{}, len(encountered))
+ for k, v := range encountered {
+ encounteredCopy[k] = v
+ }
+
+ key := cycleKey(checkTrace)
+ if _, ok := encounteredCopy[key]; ok {
+ return true
+ }
+
+ encounteredCopy[key] = struct{}{}
+
+ for _, subProblem := range checkTrace.GetSubProblems().Traces {
+ if isPartOfCycle(subProblem, encounteredCopy) {
+ return true
+ }
+ }
+
+ return false
+}
diff --git a/vendor/github.com/authzed/zed/internal/printers/table.go b/vendor/github.com/authzed/zed/internal/printers/table.go
new file mode 100644
index 0000000..fd6f024
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/printers/table.go
@@ -0,0 +1,26 @@
+package printers
+
+import (
+ "io"
+
+ "github.com/olekukonko/tablewriter"
+)
+
+// PrintTable writes an terminal-friendly table of the values to the target.
+func PrintTable(target io.Writer, headers []string, rows [][]string) {
+ table := tablewriter.NewWriter(target)
+ table.SetHeader(headers)
+ table.SetAutoWrapText(false)
+ table.SetAutoFormatHeaders(true)
+ table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
+ table.SetAlignment(tablewriter.ALIGN_LEFT)
+ table.SetCenterSeparator("")
+ table.SetColumnSeparator("")
+ table.SetRowSeparator("")
+ table.SetHeaderLine(false)
+ table.SetBorder(false)
+ table.SetTablePadding("\t")
+ table.SetNoWhiteSpace(true)
+ table.AppendBulk(rows)
+ table.Render()
+}
diff --git a/vendor/github.com/authzed/zed/internal/printers/tree.go b/vendor/github.com/authzed/zed/internal/printers/tree.go
new file mode 100644
index 0000000..b787210
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/printers/tree.go
@@ -0,0 +1,66 @@
+package printers
+
+import (
+ "fmt"
+
+ "github.com/jzelinskie/stringz"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+)
+
+func prettySubject(subj *v1.SubjectReference) string {
+ if subj.OptionalRelation == "" {
+ return fmt.Sprintf(
+ "%s:%s",
+ stringz.TrimPrefixIndex(subj.Object.ObjectType, "/"),
+ subj.Object.ObjectId,
+ )
+ }
+ return fmt.Sprintf(
+ "%s:%s->%s",
+ stringz.TrimPrefixIndex(subj.Object.ObjectType, "/"),
+ subj.Object.ObjectId,
+ subj.OptionalRelation,
+ )
+}
+
+// TreeNodeTree walks an Authzed Tree Node and creates corresponding nodes
+// for a treeprinter.
+func TreeNodeTree(tp *TreePrinter, treeNode *v1.PermissionRelationshipTree) {
+ if treeNode.ExpandedObject != nil {
+ tp = tp.Child(fmt.Sprintf(
+ "%s:%s->%s",
+ stringz.TrimPrefixIndex(treeNode.ExpandedObject.ObjectType, "/"),
+ treeNode.ExpandedObject.ObjectId,
+ treeNode.ExpandedRelation,
+ ))
+ }
+ switch typed := treeNode.TreeType.(type) {
+ case *v1.PermissionRelationshipTree_Intermediate:
+ switch typed.Intermediate.Operation {
+ case v1.AlgebraicSubjectSet_OPERATION_UNION:
+ union := tp.Child("union")
+ for _, child := range typed.Intermediate.Children {
+ TreeNodeTree(union, child)
+ }
+ case v1.AlgebraicSubjectSet_OPERATION_INTERSECTION:
+ intersection := tp.Child("intersection")
+ for _, child := range typed.Intermediate.Children {
+ TreeNodeTree(intersection, child)
+ }
+ case v1.AlgebraicSubjectSet_OPERATION_EXCLUSION:
+ exclusion := tp.Child("exclusion")
+ for _, child := range typed.Intermediate.Children {
+ TreeNodeTree(exclusion, child)
+ }
+ default:
+ panic("unknown expand operation")
+ }
+ case *v1.PermissionRelationshipTree_Leaf:
+ for _, subject := range typed.Leaf.Subjects {
+ tp.Child(prettySubject(subject))
+ }
+ default:
+ panic("unknown TreeNode type")
+ }
+}
diff --git a/vendor/github.com/authzed/zed/internal/printers/treeprinter.go b/vendor/github.com/authzed/zed/internal/printers/treeprinter.go
new file mode 100644
index 0000000..55c3830
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/printers/treeprinter.go
@@ -0,0 +1,43 @@
+package printers
+
+import (
+ "strings"
+
+ "github.com/xlab/treeprint"
+
+ "github.com/authzed/zed/internal/console"
+)
+
// TreePrinter renders a tree of strings to the console, backed lazily by a
// treeprint.Tree. The tree is nil until the first call to Child, which
// establishes the root.
type TreePrinter struct {
	tree treeprint.Tree
}

// NewTreePrinter returns an empty TreePrinter ready to accept children.
func NewTreePrinter() *TreePrinter {
	return &TreePrinter{}
}
+
+func (tp *TreePrinter) Child(val string) *TreePrinter {
+ if tp.tree == nil {
+ tp.tree = treeprint.NewWithRoot(val)
+ return tp
+ }
+ return &TreePrinter{tree: tp.tree.AddBranch(val)}
+}
+
// Print writes the rendered tree to the console via console.Println.
func (tp *TreePrinter) Print() {
	console.Println(tp.String())
}
+
+func (tp *TreePrinter) Indented() string {
+ var sb strings.Builder
+ lines := strings.Split(tp.String(), "\n")
+ for _, line := range lines {
+ sb.WriteString(" " + line + "\n")
+ }
+
+ return sb.String()
+}
+
// String returns the rendered text form of the tree.
// NOTE(review): panics if called before any Child has been added (tp.tree is
// nil at that point) — confirm all callers add a root first.
func (tp *TreePrinter) String() string {
	return tp.tree.String()
}
diff --git a/vendor/github.com/authzed/zed/internal/storage/config.go b/vendor/github.com/authzed/zed/internal/storage/config.go
new file mode 100644
index 0000000..1fca679
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/storage/config.go
@@ -0,0 +1,194 @@
+package storage
+
+import (
+ "encoding/json"
+ "errors"
+ "io/fs"
+ "os"
+ "path/filepath"
+ "runtime"
+
+ "github.com/jzelinskie/stringz"
+)
+
// configFileName is the name of the JSON file, inside the configured
// directory, that holds the zed configuration.
const configFileName = "config.json"

// ErrConfigNotFound is returned if there is no Config in a ConfigStore.
var ErrConfigNotFound = errors.New("config did not exist")

// ErrTokenNotFound is returned if there is no Token in a ConfigStore.
var ErrTokenNotFound = errors.New("token does not exist")
+
// Config represents the contents of a zed configuration file.
type Config struct {
	// Version is the schema version of the configuration (e.g. "v1").
	Version string
	// CurrentToken is the name of the token currently selected for use.
	CurrentToken string
}

// ConfigStore is anything that can persistently store a Config.
type ConfigStore interface {
	// Get returns the stored Config, or ErrConfigNotFound when none exists.
	Get() (Config, error)
	// Put persists the given Config, replacing any existing one.
	Put(Config) error
	// Exists reports whether a Config is present in the store.
	Exists() (bool, error)
}
+
+// TokenWithOverride returns a Token that retrieves its values from the reference Token, and has its values overridden
+// any of the non-empty/non-nil values of the overrideToken.
+func TokenWithOverride(overrideToken Token, referenceToken Token) (Token, error) {
+ insecure := referenceToken.Insecure
+ if overrideToken.Insecure != nil {
+ insecure = overrideToken.Insecure
+ }
+
+ // done so that logging messages don't show nil for the resulting context
+ if insecure == nil {
+ bFalse := false
+ insecure = &bFalse
+ }
+
+ noVerifyCA := referenceToken.NoVerifyCA
+ if overrideToken.NoVerifyCA != nil {
+ noVerifyCA = overrideToken.NoVerifyCA
+ }
+
+ // done so that logging messages don't show nil for the resulting context
+ if noVerifyCA == nil {
+ bFalse := false
+ noVerifyCA = &bFalse
+ }
+
+ caCert := referenceToken.CACert
+ if overrideToken.CACert != nil {
+ caCert = overrideToken.CACert
+ }
+
+ return Token{
+ Name: referenceToken.Name,
+ Endpoint: stringz.DefaultEmpty(overrideToken.Endpoint, referenceToken.Endpoint),
+ APIToken: stringz.DefaultEmpty(overrideToken.APIToken, referenceToken.APIToken),
+ Insecure: insecure,
+ NoVerifyCA: noVerifyCA,
+ CACert: caCert,
+ }, nil
+}
+
// CurrentToken is a convenient way to obtain the CurrentToken field from the
// current Config.
//
// Errors from the config store (including ErrConfigNotFound) are returned
// as-is; when the named token is missing from the secret store, the zero
// Token and a nil error are returned (see GetTokenIfExists).
func CurrentToken(cs ConfigStore, ss SecretStore) (token Token, err error) {
	cfg, err := cs.Get()
	if err != nil {
		return Token{}, err
	}

	return GetTokenIfExists(cfg.CurrentToken, ss)
}
+
+// SetCurrentToken is a convenient way to set the CurrentToken field in a
+// the current config.
+func SetCurrentToken(name string, cs ConfigStore, ss SecretStore) error {
+ // Ensure the token exists
+ exists, err := TokenExists(name, ss)
+ if err != nil {
+ return err
+ }
+
+ if !exists {
+ return ErrTokenNotFound
+ }
+
+ cfg, err := cs.Get()
+ if err != nil {
+ if errors.Is(err, ErrConfigNotFound) {
+ cfg = Config{Version: "v1"}
+ } else {
+ return err
+ }
+ }
+
+ cfg.CurrentToken = name
+ return cs.Put(cfg)
+}
+
// JSONConfigStore implements a ConfigStore that stores its Config in a JSON file at the provided ConfigPath.
type JSONConfigStore struct {
	// ConfigPath is the directory that contains (or will contain) config.json.
	ConfigPath string
}

// Enforce that our implementation satisfies the interface.
var _ ConfigStore = JSONConfigStore{}
+
+// Get parses a Config from the filesystem.
+func (s JSONConfigStore) Get() (Config, error) {
+ cfgBytes, err := os.ReadFile(filepath.Join(s.ConfigPath, configFileName))
+ if errors.Is(err, fs.ErrNotExist) {
+ return Config{}, ErrConfigNotFound
+ } else if err != nil {
+ return Config{}, err
+ }
+
+ var cfg Config
+ if err := json.Unmarshal(cfgBytes, &cfg); err != nil {
+ return Config{}, err
+ }
+
+ return cfg, nil
+}
+
+// Put overwrites a Config on the filesystem.
+func (s JSONConfigStore) Put(cfg Config) error {
+ if err := os.MkdirAll(s.ConfigPath, 0o774); err != nil {
+ return err
+ }
+
+ cfgBytes, err := json.Marshal(cfg)
+ if err != nil {
+ return err
+ }
+
+ return atomicWriteFile(filepath.Join(s.ConfigPath, configFileName), cfgBytes, 0o774)
+}
+
+func (s JSONConfigStore) Exists() (bool, error) {
+ if _, err := os.Stat(filepath.Join(s.ConfigPath, configFileName)); errors.Is(err, fs.ErrNotExist) {
+ return false, nil
+ } else if err != nil {
+ return false, err
+ }
+ return true, nil
+}
+
+// atomicWriteFile writes data to filename+some suffix, then renames it into
+// filename.
+//
+// Copyright (c) 2019 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be found
+// at the following URL:
+// https://github.com/tailscale/tailscale/blob/main/LICENSE
+func atomicWriteFile(filename string, data []byte, perm os.FileMode) (err error) {
+ f, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename)+".tmp")
+ if err != nil {
+ return err
+ }
+ tmpName := f.Name()
+ defer func() {
+ if err != nil {
+ f.Close()
+ os.Remove(tmpName)
+ }
+ }()
+ if _, err := f.Write(data); err != nil {
+ return err
+ }
+ if runtime.GOOS != "windows" {
+ if err := f.Chmod(perm); err != nil {
+ return err
+ }
+ }
+ if err := f.Sync(); err != nil {
+ return err
+ }
+ if err := f.Close(); err != nil {
+ return err
+ }
+ return os.Rename(tmpName, filename)
+}
diff --git a/vendor/github.com/authzed/zed/internal/storage/secrets.go b/vendor/github.com/authzed/zed/internal/storage/secrets.go
new file mode 100644
index 0000000..3436f6b
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/storage/secrets.go
@@ -0,0 +1,265 @@
+package storage
+
+import (
+ "encoding/json"
+ "errors"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/99designs/keyring"
+ "github.com/charmbracelet/x/term"
+ "github.com/jzelinskie/stringz"
+
+ "github.com/authzed/zed/internal/console"
+)
+
+// Token is a named credential for reaching a SpiceDB endpoint, together
+// with the TLS options used when dialing it.
+type Token struct {
+	// Name identifies this token within the local secret store.
+	Name string
+	// Endpoint is the address this token authenticates against.
+	Endpoint string
+	// APIToken is the secret; use Redacted() before logging it.
+	APIToken string
+	// Insecure, when non-nil and true, means dial without TLS (see IsInsecure).
+	Insecure *bool
+	// NoVerifyCA, when non-nil and true, skips CA verification (see HasNoVerifyCA).
+	NoVerifyCA *bool
+	// CACert is an optional CA certificate used for the endpoint —
+	// presumably PEM bytes; TODO confirm against callers.
+	CACert []byte
+}
+
+func (t Token) AnyValue() bool {
+ if t.Endpoint != "" || t.APIToken != "" || t.Insecure != nil || t.NoVerifyCA != nil || len(t.CACert) > 0 {
+ return true
+ }
+
+ return false
+}
+
+func (t Token) Certificate() (cert []byte, ok bool) {
+ if len(t.CACert) > 0 {
+ return t.CACert, true
+ }
+ return nil, false
+}
+
+func (t Token) IsInsecure() bool {
+ return t.Insecure != nil && *t.Insecure
+}
+
+func (t Token) HasNoVerifyCA() bool {
+ return t.NoVerifyCA != nil && *t.NoVerifyCA
+}
+
+func (t Token) Redacted() string {
+ prefix, _ := t.SplitAPIToken()
+ if prefix == "" {
+ return "<redacted>"
+ }
+
+ return stringz.Join("_", prefix, "<redacted>")
+}
+
+func (t Token) SplitAPIToken() (prefix, secret string) {
+ exploded := strings.Split(t.APIToken, "_")
+ return strings.Join(exploded[:len(exploded)-1], "_"), exploded[len(exploded)-1]
+}
+
+// Secrets is the full set of tokens persisted by a SecretStore.
+type Secrets struct {
+	// Tokens holds every stored credential; Token.Name acts as the key.
+	Tokens []Token
+}
+
+// SecretStore loads and saves the entire Secrets payload as one unit.
+type SecretStore interface {
+	// Get returns the currently stored secrets.
+	Get() (Secrets, error)
+	// Put replaces the stored secrets wholesale.
+	Put(s Secrets) error
+}
+
+// GetTokenIfExists returns an empty token if no token exists.
+func GetTokenIfExists(name string, ss SecretStore) (Token, error) {
+ secrets, err := ss.Get()
+ if err != nil {
+ return Token{}, err
+ }
+
+ for _, token := range secrets.Tokens {
+ if name == token.Name {
+ return token, nil
+ }
+ }
+
+ return Token{}, nil
+}
+
+func TokenExists(name string, ss SecretStore) (bool, error) {
+ secrets, err := ss.Get()
+ if err != nil {
+ return false, err
+ }
+
+ for _, token := range secrets.Tokens {
+ if name == token.Name {
+ return true, nil
+ }
+ }
+
+ return false, nil
+}
+
+func PutToken(t Token, ss SecretStore) error {
+ secrets, err := ss.Get()
+ if err != nil {
+ return err
+ }
+
+ replaced := false
+ for i, token := range secrets.Tokens {
+ if token.Name == t.Name {
+ secrets.Tokens[i] = t
+ replaced = true
+ }
+ }
+
+ if !replaced {
+ secrets.Tokens = append(secrets.Tokens, t)
+ }
+
+ return ss.Put(secrets)
+}
+
+func RemoveToken(name string, ss SecretStore) error {
+ secrets, err := ss.Get()
+ if err != nil {
+ return err
+ }
+
+ for i, token := range secrets.Tokens {
+ if token.Name == name {
+ secrets.Tokens = append(secrets.Tokens[:i], secrets.Tokens[i+1:]...)
+ break
+ }
+ }
+
+ return ss.Put(secrets)
+}
+
+// KeychainSecretStore persists Secrets in a passphrase-protected
+// keyring file kept under ConfigPath.
+type KeychainSecretStore struct {
+	ConfigPath string
+	// ring caches the opened keyring so the passphrase is only
+	// resolved once per process (see the keyring method).
+	ring keyring.Keyring
+}
+
+// Compile-time assertion that the store implements SecretStore.
+var _ SecretStore = (*KeychainSecretStore)(nil)
+
+const (
+	// svcName is the keyring service name for zed's entries.
+	svcName          = "zed"
+	// keyringEntryName is the single key under which all secrets are stored.
+	keyringEntryName = svcName + " secrets"
+	// envRecommendation is prepended to prompts to advertise the
+	// non-interactive ZED_KEYRING_PASSWORD escape hatch.
+	envRecommendation         = "Setting the environment variable `ZED_KEYRING_PASSWORD` to your password will skip prompts.\n"
+	// keyringDoesNotExistPrompt is shown on first run, when a new
+	// keyring file (and passphrase) must be created.
+	keyringDoesNotExistPrompt = "Keyring file does not already exist.\nEnter a new non-empty passphrase for the new keyring file: "
+	// keyringPrompt is shown when unlocking an existing keyring.
+	keyringPrompt             = "Enter passphrase to unlock zed keyring: "
+	// emptyKeyringPasswordError rejects empty passphrases on creation.
+	emptyKeyringPasswordError = "your passphrase must not be empty"
+)
+
+func fileExists(path string) (bool, error) {
+ _, err := os.Stat(path)
+ switch {
+ case err == nil:
+ return true, nil
+ case os.IsNotExist(err):
+ return false, nil
+ default:
+ return false, err
+ }
+}
+
+func promptPassword(prompt string) (string, error) {
+ console.Printf(prompt)
+ b, err := term.ReadPassword(os.Stdin.Fd())
+ if err != nil {
+ return "", err
+ }
+ console.Printf("\n") // Clear the line after a prompt
+ return string(b), err
+}
+
+func (k *KeychainSecretStore) keyring() (keyring.Keyring, error) {
+ if k.ring != nil {
+ return k.ring, nil
+ }
+
+ keyringPath := filepath.Join(k.ConfigPath, "keyring.jwt")
+
+ ring, err := keyring.Open(keyring.Config{
+ ServiceName: "zed",
+ FileDir: keyringPath,
+ FilePasswordFunc: func(_ string) (string, error) {
+ if password, ok := os.LookupEnv("ZED_KEYRING_PASSWORD"); ok {
+ return password, nil
+ }
+
+ // Check if this is the first run where the keyring is created.
+ keyringExists, err := fileExists(filepath.Join(keyringPath, keyringEntryName))
+ if err != nil {
+ return "", err
+ }
+ if !keyringExists {
+ // This is the first run and we're creating a password.
+ passwordString, err := promptPassword(envRecommendation + keyringDoesNotExistPrompt)
+ if err != nil {
+ return "", err
+ }
+
+ if len(passwordString) == 0 {
+ // NOTE: we enforce a non-empty keyring password to prevent
+ // user frustration around accidentally setting an empty
+ // passphrase and then not knowing what it might be.
+ return "", errors.New(emptyKeyringPasswordError)
+ }
+
+ return passwordString, nil
+ }
+
+ passwordString, err := promptPassword(envRecommendation + keyringPrompt)
+ if err != nil {
+ return "", err
+ }
+
+ return passwordString, nil
+ },
+ })
+ if err != nil {
+ return ring, err
+ }
+
+ k.ring = ring
+ return ring, err
+}
+
+func (k *KeychainSecretStore) Get() (Secrets, error) {
+ ring, err := k.keyring()
+ if err != nil {
+ return Secrets{}, err
+ }
+
+ entry, err := ring.Get(keyringEntryName)
+ if err != nil {
+ if errors.Is(err, keyring.ErrKeyNotFound) {
+ return Secrets{}, nil // empty is okay!
+ }
+ return Secrets{}, err
+ }
+
+ var s Secrets
+ err = json.Unmarshal(entry.Data, &s)
+ return s, err
+}
+
+func (k *KeychainSecretStore) Put(s Secrets) error {
+ ring, err := k.keyring()
+ if err != nil {
+ return err
+ }
+
+ data, err := json.Marshal(s)
+ if err != nil {
+ return err
+ }
+
+ return ring.Set(keyring.Item{
+ Key: keyringEntryName,
+ Data: data,
+ })
+}