summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/zed/internal/grpcutil
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
committermo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
commit20ef0d92694465ac86b550df139e8366a0a2b4fa (patch)
tree3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/zed/internal/grpcutil
parent44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff)
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/zed/internal/grpcutil')
-rw-r--r--vendor/github.com/authzed/zed/internal/grpcutil/batch.go68
-rw-r--r--vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go164
2 files changed, 232 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/zed/internal/grpcutil/batch.go b/vendor/github.com/authzed/zed/internal/grpcutil/batch.go
new file mode 100644
index 0000000..640085c
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/grpcutil/batch.go
@@ -0,0 +1,68 @@
+package grpcutil
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "runtime"
+
+ "golang.org/x/sync/errgroup"
+ "golang.org/x/sync/semaphore"
+)
+
+func minimum(a int, b int) int {
+ if a <= b {
+ return a
+ }
+ return b
+}
+
+// EachFunc is a callback function that is called for each batch. no is the
+// batch number, start is the starting index of this batch in the slice, and
+// end is the ending index of this batch in the slice.
+type EachFunc func(ctx context.Context, no int, start int, end int) error
+
+// ConcurrentBatch will calculate the minimum number of batches to required to batch n items
+// with batchSize batches. For each batch, it will execute the each function.
+// These functions will be processed in parallel using maxWorkers number of
+// goroutines. If maxWorkers is 1, then batching will happen sychronously. If
+// maxWorkers is 0, then GOMAXPROCS number of workers will be used.
+//
+// If an error occurs during a batch, all the worker's contexts are cancelled
+// and the original error is returned.
+func ConcurrentBatch(ctx context.Context, n int, batchSize int, maxWorkers int, each EachFunc) error {
+ if n < 0 {
+ return errors.New("cannot batch items of length < 0")
+ } else if n == 0 {
+ // Batching zero items is a noop.
+ return nil
+ }
+
+ if batchSize < 1 {
+ return errors.New("cannot batch items with batch size < 1")
+ }
+
+ if maxWorkers < 0 {
+ return errors.New("cannot batch items with workers < 0")
+ } else if maxWorkers == 0 {
+ maxWorkers = runtime.GOMAXPROCS(0)
+ }
+
+ sem := semaphore.NewWeighted(int64(maxWorkers))
+ g, ctx := errgroup.WithContext(ctx)
+ numBatches := (n + batchSize - 1) / batchSize
+ for i := 0; i < numBatches; i++ {
+ if err := sem.Acquire(ctx, 1); err != nil {
+ return fmt.Errorf("failed to acquire semaphore for batch number %d: %w", i, err)
+ }
+
+ batchNum := i
+ g.Go(func() error {
+ defer sem.Release(1)
+ start := batchNum * batchSize
+ end := minimum(start+batchSize, n)
+ return each(ctx, batchNum, start, end)
+ })
+ }
+ return g.Wait()
+}
diff --git a/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go b/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go
new file mode 100644
index 0000000..c6537b9
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go
@@ -0,0 +1,164 @@
+package grpcutil
+
+import (
+ "context"
+ "errors"
+ "io"
+ "sync"
+ "time"
+
+ "github.com/rs/zerolog/log"
+ "golang.org/x/mod/semver"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+
+ "github.com/authzed/authzed-go/pkg/requestmeta"
+ "github.com/authzed/authzed-go/pkg/responsemeta"
+ "github.com/authzed/spicedb/pkg/releases"
+)
+
+// Compile-time assertion that LogDispatchTrailers and CheckServerVersion implement the
+// grpc.UnaryClientInterceptor interface.
+var (
+ _ grpc.UnaryClientInterceptor = grpc.UnaryClientInterceptor(LogDispatchTrailers)
+ _ grpc.UnaryClientInterceptor = grpc.UnaryClientInterceptor(CheckServerVersion)
+)
+
+var once sync.Once
+
+// CheckServerVersion implements a gRPC unary interceptor that requests the server version
+// from SpiceDB and, if found, compares it to the current released version.
+func CheckServerVersion(
+ ctx context.Context,
+ method string,
+ req, reply interface{},
+ cc *grpc.ClientConn,
+ invoker grpc.UnaryInvoker,
+ callOpts ...grpc.CallOption,
+) error {
+ var headerMD metadata.MD
+ ctx = requestmeta.AddRequestHeaders(ctx, requestmeta.RequestServerVersion)
+ err := invoker(ctx, method, req, reply, cc, append(callOpts, grpc.Header(&headerMD))...)
+ if err != nil {
+ return err
+ }
+
+ once.Do(func() {
+ version := headerMD.Get(string(responsemeta.ServerVersion))
+ if len(version) == 0 {
+ log.Debug().Msg("error reading server version response header; it may be disabled on the server")
+ } else if len(version) == 1 {
+ currentVersion := version[0]
+
+ // If there is a build on the version, then do not compare.
+ if semver.Build(currentVersion) != "" {
+ log.Debug().Str("this-version", currentVersion).Msg("received build version of SpiceDB")
+ return
+ }
+
+ rctx, cancel := context.WithTimeout(ctx, time.Second*2)
+ defer cancel()
+
+ state, _, release, cerr := releases.CheckIsLatestVersion(rctx, func() (string, error) {
+ return currentVersion, nil
+ }, releases.GetLatestRelease)
+ if cerr != nil {
+ log.Debug().Err(cerr).Msg("error looking up currently released version")
+ } else {
+ switch state {
+ case releases.UnreleasedVersion:
+ log.Warn().Str("version", currentVersion).Msg("not calling a released version of SpiceDB")
+ return
+
+ case releases.UpdateAvailable:
+ log.Warn().Str("this-version", currentVersion).Str("latest-released-version", release.Version).Msgf("the version of SpiceDB being called is out of date. See: %s", release.ViewURL)
+ return
+
+ case releases.UpToDate:
+ log.Debug().Str("latest-released-version", release.Version).Msg("the version of SpiceDB being called is the latest released version")
+ return
+
+ case releases.Unknown:
+ log.Warn().Str("unknown-released-version", release.Version).Msg("unable to check for a new SpiceDB version")
+ return
+
+ default:
+ panic("Unknown state for CheckAndLogRunE")
+ }
+ }
+ }
+ })
+
+ return nil
+}
+
+// LogDispatchTrailers implements a gRPC unary interceptor that logs the
+// dispatch metadata that is present in response trailers from SpiceDB.
+func LogDispatchTrailers(
+ ctx context.Context,
+ method string,
+ req, reply interface{},
+ cc *grpc.ClientConn,
+ invoker grpc.UnaryInvoker,
+ callOpts ...grpc.CallOption,
+) error {
+ var trailerMD metadata.MD
+ err := invoker(ctx, method, req, reply, cc, append(callOpts, grpc.Trailer(&trailerMD))...)
+ outputDispatchTrailers(trailerMD)
+ return err
+}
+
+func outputDispatchTrailers(trailerMD metadata.MD) {
+ log.Trace().Interface("trailers", trailerMD).Msg("parsed trailers")
+
+ dispatchCount, trailerErr := responsemeta.GetIntResponseTrailerMetadata(
+ trailerMD,
+ responsemeta.DispatchedOperationsCount,
+ )
+ if trailerErr != nil {
+ log.Debug().Err(trailerErr).Msg("error reading dispatched operations trailer")
+ }
+
+ cachedCount, trailerErr := responsemeta.GetIntResponseTrailerMetadata(
+ trailerMD,
+ responsemeta.CachedOperationsCount,
+ )
+ if trailerErr != nil {
+ log.Debug().Err(trailerErr).Msg("error reading cached operations trailer")
+ }
+
+ log.Debug().
+ Int("dispatch", dispatchCount).
+ Int("cached", cachedCount).
+ Msg("extracted response dispatch metadata")
+}
+
+// StreamLogDispatchTrailers implements a gRPC stream interceptor that logs the
+// dispatch metadata that is present in response trailers from SpiceDB.
+func StreamLogDispatchTrailers(
+ ctx context.Context,
+ desc *grpc.StreamDesc,
+ cc *grpc.ClientConn,
+ method string,
+ streamer grpc.Streamer,
+ callOpts ...grpc.CallOption,
+) (grpc.ClientStream, error) {
+ stream, err := streamer(ctx, desc, cc, method, callOpts...)
+ if err != nil {
+ return nil, err
+ }
+
+ return &wrappedStream{stream}, nil
+}
+
+type wrappedStream struct {
+ grpc.ClientStream
+}
+
+func (w *wrappedStream) RecvMsg(m interface{}) error {
+ err := w.ClientStream.RecvMsg(m)
+ if err != nil && errors.Is(err, io.EOF) {
+ outputDispatchTrailers(w.Trailer())
+ }
+ return err
+}