summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/internal/dispatch/stream.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/authzed/spicedb/internal/dispatch/stream.go')
-rw-r--r--vendor/github.com/authzed/spicedb/internal/dispatch/stream.go187
1 files changed, 187 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/stream.go b/vendor/github.com/authzed/spicedb/internal/dispatch/stream.go
new file mode 100644
index 0000000..1d6636c
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/dispatch/stream.go
@@ -0,0 +1,187 @@
+package dispatch
+
+import (
+ "context"
+ "sync"
+ "sync/atomic"
+
+ grpc "google.golang.org/grpc"
+)
+
+// Stream defines the interface generically matching a streaming dispatch response.
+type Stream[T any] interface {
+ // Publish publishes the result to the stream.
+ Publish(T) error
+
+ // Context returns the context for the stream.
+ Context() context.Context
+}
+
+type grpcStream[T any] interface {
+ grpc.ServerStream
+ Send(T) error
+}
+
+// WrapGRPCStream wraps a gRPC result stream with a concurrent-safe dispatch stream. This is
+// necessary because gRPC response streams are *not concurrent safe*.
+// See: https://groups.google.com/g/grpc-io/c/aI6L6M4fzQ0?pli=1
+func WrapGRPCStream[R any, S grpcStream[R]](grpcStream S) Stream[R] {
+ return &concurrentSafeStream[R]{
+ grpcStream: grpcStream,
+ mu: sync.Mutex{},
+ }
+}
+
+type concurrentSafeStream[T any] struct {
+ grpcStream grpcStream[T] // GUARDED_BY(mu)
+ mu sync.Mutex
+}
+
+func (s *concurrentSafeStream[T]) Context() context.Context {
+ return s.grpcStream.Context()
+}
+
+func (s *concurrentSafeStream[T]) Publish(result T) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.grpcStream.Send(result)
+}
+
+// NewCollectingDispatchStream creates a new CollectingDispatchStream.
+func NewCollectingDispatchStream[T any](ctx context.Context) *CollectingDispatchStream[T] {
+ return &CollectingDispatchStream[T]{
+ ctx: ctx,
+ results: nil,
+ mu: sync.Mutex{},
+ }
+}
+
+// CollectingDispatchStream is a dispatch stream that collects results in memory.
+type CollectingDispatchStream[T any] struct {
+ ctx context.Context
+ results []T // GUARDED_BY(mu)
+ mu sync.Mutex
+}
+
+func (s *CollectingDispatchStream[T]) Context() context.Context {
+ return s.ctx
+}
+
+func (s *CollectingDispatchStream[T]) Results() []T {
+ return s.results
+}
+
+func (s *CollectingDispatchStream[T]) Publish(result T) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.results = append(s.results, result)
+ return nil
+}
+
+// WrappedDispatchStream is a dispatch stream that wraps another dispatch stream, and performs
+// an operation on each result before puppeting back up to the parent stream.
+type WrappedDispatchStream[T any] struct {
+ Stream Stream[T]
+ Ctx context.Context
+ Processor func(result T) (T, bool, error)
+}
+
+func (s *WrappedDispatchStream[T]) Publish(result T) error {
+ if s.Processor == nil {
+ return s.Stream.Publish(result)
+ }
+
+ processed, ok, err := s.Processor(result)
+ if err != nil {
+ return err
+ }
+ if !ok {
+ return nil
+ }
+
+ return s.Stream.Publish(processed)
+}
+
+func (s *WrappedDispatchStream[T]) Context() context.Context {
+ return s.Ctx
+}
+
+// StreamWithContext returns the given dispatch stream, wrapped to return the given context.
+func StreamWithContext[T any](context context.Context, stream Stream[T]) Stream[T] {
+ return &WrappedDispatchStream[T]{
+ Stream: stream,
+ Ctx: context,
+ Processor: nil,
+ }
+}
+
+// HandlingDispatchStream is a dispatch stream that executes a handler for each item published.
+// It uses an internal mutex to ensure it is thread safe.
+type HandlingDispatchStream[T any] struct {
+ ctx context.Context
+ processor func(result T) error // GUARDED_BY(mu)
+ mu sync.Mutex
+}
+
+// NewHandlingDispatchStream returns a new handling dispatch stream.
+func NewHandlingDispatchStream[T any](ctx context.Context, processor func(result T) error) Stream[T] {
+ return &HandlingDispatchStream[T]{
+ ctx: ctx,
+ processor: processor,
+ mu: sync.Mutex{},
+ }
+}
+
+func (s *HandlingDispatchStream[T]) Publish(result T) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.processor == nil {
+ return nil
+ }
+
+ return s.processor(result)
+}
+
+func (s *HandlingDispatchStream[T]) Context() context.Context {
+ return s.ctx
+}
+
+// CountingDispatchStream is a dispatch stream that counts the number of items published.
+// It uses an internal atomic int to ensure it is thread safe.
+type CountingDispatchStream[T any] struct {
+ Stream Stream[T]
+ count *atomic.Uint64
+}
+
+func NewCountingDispatchStream[T any](wrapped Stream[T]) *CountingDispatchStream[T] {
+ return &CountingDispatchStream[T]{
+ Stream: wrapped,
+ count: &atomic.Uint64{},
+ }
+}
+
+func (s *CountingDispatchStream[T]) PublishedCount() uint64 {
+ return s.count.Load()
+}
+
+func (s *CountingDispatchStream[T]) Publish(result T) error {
+ err := s.Stream.Publish(result)
+ if err != nil {
+ return err
+ }
+
+ s.count.Add(1)
+ return nil
+}
+
+func (s *CountingDispatchStream[T]) Context() context.Context {
+ return s.Stream.Context()
+}
+
+// Ensure the streams implement the interface.
+var (
+ _ Stream[any] = &CollectingDispatchStream[any]{}
+ _ Stream[any] = &WrappedDispatchStream[any]{}
+ _ Stream[any] = &CountingDispatchStream[any]{}
+)