summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/zed/internal/cmd/import.go
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
committermo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
commit20ef0d92694465ac86b550df139e8366a0a2b4fa (patch)
tree3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/zed/internal/cmd/import.go
parent44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff)
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/zed/internal/cmd/import.go')
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/import.go181
1 files changed, 181 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/zed/internal/cmd/import.go b/vendor/github.com/authzed/zed/internal/cmd/import.go
new file mode 100644
index 0000000..687d42e
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/import.go
@@ -0,0 +1,181 @@
+package cmd
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "net/url"
+ "strings"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+ "github.com/authzed/spicedb/pkg/validationfile"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/decode"
+ "github.com/authzed/zed/internal/grpcutil"
+)
+
+func registerImportCmd(rootCmd *cobra.Command) {
+ rootCmd.AddCommand(importCmd)
+ importCmd.Flags().Int("batch-size", 1000, "import batch size")
+ importCmd.Flags().Int("workers", 1, "number of concurrent batching workers")
+ importCmd.Flags().Bool("schema", true, "import schema")
+ importCmd.Flags().Bool("relationships", true, "import relationships")
+ importCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before importing")
+}
+
+var importCmd = &cobra.Command{
+ Use: "import <url>",
+ Short: "Imports schema and relationships from a file or url",
+ Example: `
+ From a gist:
+ zed import https://gist.github.com/ecordell/8e3b613a677e3c844742cf24421c08b6
+
+ From a playground link:
+ zed import https://play.authzed.com/s/iksdFvCtvnkR/schema
+
+ From pastebin:
+ zed import https://pastebin.com/8qU45rVK
+
+ From a devtools instance:
+ zed import https://localhost:8443/download
+
+ From a local file (with prefix):
+ zed import file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
+
+ From a local file (no prefix):
+ zed import authzed-x7izWU8_2Gw3.yaml
+
+ Only schema:
+ zed import --relationships=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
+
+ Only relationships:
+ zed import --schema=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
+
+ With schema definition prefix:
+ zed import --schema-definition-prefix=mypermsystem file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
+`,
+ Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
+ RunE: importCmdFunc,
+}
+
+func importCmdFunc(cmd *cobra.Command, args []string) error {
+ client, err := client.NewClient(cmd)
+ if err != nil {
+ return err
+ }
+ u, err := url.Parse(args[0])
+ if err != nil {
+ return err
+ }
+
+ decoder, err := decode.DecoderForURL(u)
+ if err != nil {
+ return err
+ }
+ var p validationfile.ValidationFile
+ if _, _, err := decoder(&p); err != nil {
+ return err
+ }
+
+ prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), client, nil)
+ if err != nil {
+ return err
+ }
+
+ if cobrautil.MustGetBool(cmd, "schema") {
+ if err := importSchema(cmd.Context(), client, p.Schema.Schema, prefix); err != nil {
+ return err
+ }
+ }
+
+ if cobrautil.MustGetBool(cmd, "relationships") {
+ batchSize := cobrautil.MustGetInt(cmd, "batch-size")
+ workers := cobrautil.MustGetInt(cmd, "workers")
+ if err := importRelationships(cmd.Context(), client, p.Relationships.RelationshipsString, prefix, batchSize, workers); err != nil {
+ return err
+ }
+ }
+
+ return err
+}
+
+func importSchema(ctx context.Context, client client.Client, schema string, definitionPrefix string) error {
+ log.Info().Msg("importing schema")
+
+ // Recompile the schema with the specified prefix.
+ schemaText, err := rewriteSchema(schema, definitionPrefix)
+ if err != nil {
+ return err
+ }
+
+ // Write the recompiled and regenerated schema.
+ request := &v1.WriteSchemaRequest{Schema: schemaText}
+ log.Trace().Interface("request", request).Str("schema", schemaText).Msg("writing schema")
+
+ if _, err := client.WriteSchema(ctx, request); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func importRelationships(ctx context.Context, client client.Client, relationships string, definitionPrefix string, batchSize int, workers int) error {
+ relationshipUpdates := make([]*v1.RelationshipUpdate, 0)
+ scanner := bufio.NewScanner(strings.NewReader(relationships))
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" {
+ continue
+ }
+ if strings.HasPrefix(line, "//") {
+ continue
+ }
+ rel, err := tuple.ParseV1Rel(line)
+ if err != nil {
+ return fmt.Errorf("failed to parse %s as relationship", line)
+ }
+ log.Trace().Str("line", line).Send()
+
+ // Rewrite the prefix on the references, if any.
+ if len(definitionPrefix) > 0 {
+ rel.Resource.ObjectType = fmt.Sprintf("%s/%s", definitionPrefix, rel.Resource.ObjectType)
+ rel.Subject.Object.ObjectType = fmt.Sprintf("%s/%s", definitionPrefix, rel.Subject.Object.ObjectType)
+ }
+
+ relationshipUpdates = append(relationshipUpdates, &v1.RelationshipUpdate{
+ Operation: v1.RelationshipUpdate_OPERATION_TOUCH,
+ Relationship: rel,
+ })
+ }
+ if err := scanner.Err(); err != nil {
+ return err
+ }
+
+ log.Info().
+ Int("batch_size", batchSize).
+ Int("workers", workers).
+ Int("count", len(relationshipUpdates)).
+ Msg("importing relationships")
+
+ err := grpcutil.ConcurrentBatch(ctx, len(relationshipUpdates), batchSize, workers, func(ctx context.Context, no int, start int, end int) error {
+ request := &v1.WriteRelationshipsRequest{Updates: relationshipUpdates[start:end]}
+ _, err := client.WriteRelationships(ctx, request)
+ if err != nil {
+ return err
+ }
+
+ log.Info().
+ Int("batch_no", no).
+ Int("count", len(relationshipUpdates[start:end])).
+ Msg("imported relationships")
+ return nil
+ })
+ return err
+}