diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
| commit | 20ef0d92694465ac86b550df139e8366a0a2b4fa (patch) | |
| tree | 3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/zed/internal/grpcutil | |
| parent | 44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff) | |
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/zed/internal/grpcutil')
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/grpcutil/batch.go | 68 | ||||
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go | 164 |
2 files changed, 232 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/zed/internal/grpcutil/batch.go b/vendor/github.com/authzed/zed/internal/grpcutil/batch.go new file mode 100644 index 0000000..640085c --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/grpcutil/batch.go @@ -0,0 +1,68 @@ +package grpcutil + +import ( + "context" + "errors" + "fmt" + "runtime" + + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" +) + +func minimum(a int, b int) int { + if a <= b { + return a + } + return b +} + +// EachFunc is a callback function that is called for each batch. no is the +// batch number, start is the starting index of this batch in the slice, and +// end is the ending index of this batch in the slice. +type EachFunc func(ctx context.Context, no int, start int, end int) error + +// ConcurrentBatch will calculate the minimum number of batches to required to batch n items +// with batchSize batches. For each batch, it will execute the each function. +// These functions will be processed in parallel using maxWorkers number of +// goroutines. If maxWorkers is 1, then batching will happen sychronously. If +// maxWorkers is 0, then GOMAXPROCS number of workers will be used. +// +// If an error occurs during a batch, all the worker's contexts are cancelled +// and the original error is returned. +func ConcurrentBatch(ctx context.Context, n int, batchSize int, maxWorkers int, each EachFunc) error { + if n < 0 { + return errors.New("cannot batch items of length < 0") + } else if n == 0 { + // Batching zero items is a noop. + return nil + } + + if batchSize < 1 { + return errors.New("cannot batch items with batch size < 1") + } + + if maxWorkers < 0 { + return errors.New("cannot batch items with workers < 0") + } else if maxWorkers == 0 { + maxWorkers = runtime.GOMAXPROCS(0) + } + + sem := semaphore.NewWeighted(int64(maxWorkers)) + g, ctx := errgroup.WithContext(ctx) + numBatches := (n + batchSize - 1) / batchSize + for i := 0; i < numBatches; i++ { + if err := sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire semaphore for batch number %d: %w", i, err) + } + + batchNum := i + g.Go(func() error { + defer sem.Release(1) + start := batchNum * batchSize + end := minimum(start+batchSize, n) + return each(ctx, batchNum, start, end) + }) + } + return g.Wait() +} diff --git a/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go b/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go new file mode 100644 index 0000000..c6537b9 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/grpcutil/grpcutil.go @@ -0,0 +1,164 @@ +package grpcutil + +import ( + "context" + "errors" + "io" + "sync" + "time" + + "github.com/rs/zerolog/log" + "golang.org/x/mod/semver" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/authzed/authzed-go/pkg/requestmeta" + "github.com/authzed/authzed-go/pkg/responsemeta" + "github.com/authzed/spicedb/pkg/releases" +) + +// Compile-time assertion that LogDispatchTrailers and CheckServerVersion implement the +// grpc.UnaryClientInterceptor interface. +var ( + _ grpc.UnaryClientInterceptor = grpc.UnaryClientInterceptor(LogDispatchTrailers) + _ grpc.UnaryClientInterceptor = grpc.UnaryClientInterceptor(CheckServerVersion) +) + +var once sync.Once + +// CheckServerVersion implements a gRPC unary interceptor that requests the server version +// from SpiceDB and, if found, compares it to the current released version. +func CheckServerVersion( + ctx context.Context, + method string, + req, reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + callOpts ...grpc.CallOption, +) error { + var headerMD metadata.MD + ctx = requestmeta.AddRequestHeaders(ctx, requestmeta.RequestServerVersion) + err := invoker(ctx, method, req, reply, cc, append(callOpts, grpc.Header(&headerMD))...) + if err != nil { + return err + } + + once.Do(func() { + version := headerMD.Get(string(responsemeta.ServerVersion)) + if len(version) == 0 { + log.Debug().Msg("error reading server version response header; it may be disabled on the server") + } else if len(version) == 1 { + currentVersion := version[0] + + // If there is a build on the version, then do not compare. + if semver.Build(currentVersion) != "" { + log.Debug().Str("this-version", currentVersion).Msg("received build version of SpiceDB") + return + } + + rctx, cancel := context.WithTimeout(ctx, time.Second*2) + defer cancel() + + state, _, release, cerr := releases.CheckIsLatestVersion(rctx, func() (string, error) { + return currentVersion, nil + }, releases.GetLatestRelease) + if cerr != nil { + log.Debug().Err(cerr).Msg("error looking up currently released version") + } else { + switch state { + case releases.UnreleasedVersion: + log.Warn().Str("version", currentVersion).Msg("not calling a released version of SpiceDB") + return + + case releases.UpdateAvailable: + log.Warn().Str("this-version", currentVersion).Str("latest-released-version", release.Version).Msgf("the version of SpiceDB being called is out of date. See: %s", release.ViewURL) + return + + case releases.UpToDate: + log.Debug().Str("latest-released-version", release.Version).Msg("the version of SpiceDB being called is the latest released version") + return + + case releases.Unknown: + log.Warn().Str("unknown-released-version", release.Version).Msg("unable to check for a new SpiceDB version") + return + + default: + panic("Unknown state for CheckAndLogRunE") + } + } + } + }) + + return nil +} + +// LogDispatchTrailers implements a gRPC unary interceptor that logs the +// dispatch metadata that is present in response trailers from SpiceDB. +func LogDispatchTrailers( + ctx context.Context, + method string, + req, reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + callOpts ...grpc.CallOption, +) error { + var trailerMD metadata.MD + err := invoker(ctx, method, req, reply, cc, append(callOpts, grpc.Trailer(&trailerMD))...) + outputDispatchTrailers(trailerMD) + return err +} + +func outputDispatchTrailers(trailerMD metadata.MD) { + log.Trace().Interface("trailers", trailerMD).Msg("parsed trailers") + + dispatchCount, trailerErr := responsemeta.GetIntResponseTrailerMetadata( + trailerMD, + responsemeta.DispatchedOperationsCount, + ) + if trailerErr != nil { + log.Debug().Err(trailerErr).Msg("error reading dispatched operations trailer") + } + + cachedCount, trailerErr := responsemeta.GetIntResponseTrailerMetadata( + trailerMD, + responsemeta.CachedOperationsCount, + ) + if trailerErr != nil { + log.Debug().Err(trailerErr).Msg("error reading cached operations trailer") + } + + log.Debug(). + Int("dispatch", dispatchCount). + Int("cached", cachedCount). + Msg("extracted response dispatch metadata") +} + +// StreamLogDispatchTrailers implements a gRPC stream interceptor that logs the +// dispatch metadata that is present in response trailers from SpiceDB. +func StreamLogDispatchTrailers( + ctx context.Context, + desc *grpc.StreamDesc, + cc *grpc.ClientConn, + method string, + streamer grpc.Streamer, + callOpts ...grpc.CallOption, +) (grpc.ClientStream, error) { + stream, err := streamer(ctx, desc, cc, method, callOpts...) + if err != nil { + return nil, err + } + + return &wrappedStream{stream}, nil +} + +type wrappedStream struct { + grpc.ClientStream +} + +func (w *wrappedStream) RecvMsg(m interface{}) error { + err := w.ClientStream.RecvMsg(m) + if err != nil && errors.Is(err, io.EOF) { + outputDispatchTrailers(w.Trailer()) + } + return err +} |
