summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/pkg/middleware/consistency
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
committermo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
commit20ef0d92694465ac86b550df139e8366a0a2b4fa (patch)
tree3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/spicedb/pkg/middleware/consistency
parent44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff)
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/middleware/consistency')
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/middleware/consistency/consistency.go280
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/middleware/consistency/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/middleware/consistency/forcefull.go66
3 files changed, 348 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/consistency.go b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/consistency.go
new file mode 100644
index 0000000..8f75a19
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/consistency.go
@@ -0,0 +1,280 @@
+package consistency
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/internal/services/shared"
+ "github.com/authzed/spicedb/pkg/cursor"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/zedtoken"
+)
+
+var ConsistencyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
+ Namespace: "spicedb",
+ Subsystem: "middleware",
+ Name: "consistency_assigned_total",
+ Help: "Count of the consistencies used per request",
+}, []string{"method", "source", "service"})
+
+type hasConsistency interface{ GetConsistency() *v1.Consistency }
+
+type hasOptionalCursor interface{ GetOptionalCursor() *v1.Cursor }
+
+type ctxKeyType struct{}
+
+var revisionKey ctxKeyType = struct{}{}
+
+var errInvalidZedToken = errors.New("invalid revision requested")
+
+type revisionHandle struct {
+ revision datastore.Revision
+}
+
+// ContextWithHandle adds a placeholder to a context that will later be
+// filled by the revision
+func ContextWithHandle(ctx context.Context) context.Context {
+ return context.WithValue(ctx, revisionKey, &revisionHandle{})
+}
+
+// RevisionFromContext reads the selected revision out of a context.Context, computes a zedtoken
+// from it, and returns an error if it has not been set on the context.
+func RevisionFromContext(ctx context.Context) (datastore.Revision, *v1.ZedToken, error) {
+ if c := ctx.Value(revisionKey); c != nil {
+ handle := c.(*revisionHandle)
+ rev := handle.revision
+ if rev != nil {
+ return rev, zedtoken.MustNewFromRevision(rev), nil
+ }
+ }
+
+ return nil, nil, fmt.Errorf("consistency middleware did not inject revision")
+}
+
+// AddRevisionToContext adds a revision to the given context, based on the consistency block found
+// in the given request (if applicable).
+func AddRevisionToContext(ctx context.Context, req interface{}, ds datastore.Datastore, serviceLabel string) error {
+ switch req := req.(type) {
+ case hasConsistency:
+ return addRevisionToContextFromConsistency(ctx, req, ds, serviceLabel)
+ default:
+ return nil
+ }
+}
+
+// addRevisionToContextFromConsistency adds a revision to the given context, based on the consistency block found
+// in the given request (if applicable).
+func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency, ds datastore.Datastore, serviceLabel string) error {
+ handle := ctx.Value(revisionKey)
+ if handle == nil {
+ return nil
+ }
+
+ var revision datastore.Revision
+ consistency := req.GetConsistency()
+
+ withOptionalCursor, hasOptionalCursor := req.(hasOptionalCursor)
+
+ switch {
+ case hasOptionalCursor && withOptionalCursor.GetOptionalCursor() != nil:
+ // Always use the revision encoded in the cursor.
+ if serviceLabel != "" {
+ ConsistencyCounter.WithLabelValues("snapshot", "cursor", serviceLabel).Inc()
+ }
+
+ requestedRev, err := cursor.DecodeToDispatchRevision(withOptionalCursor.GetOptionalCursor(), ds)
+ if err != nil {
+ return rewriteDatastoreError(ctx, err)
+ }
+
+ err = ds.CheckRevision(ctx, requestedRev)
+ if err != nil {
+ return rewriteDatastoreError(ctx, err)
+ }
+
+ revision = requestedRev
+
+ case consistency == nil || consistency.GetMinimizeLatency():
+ // Minimize Latency: Use the datastore's current revision, whatever it may be.
+ source := "request"
+ if consistency == nil {
+ source = "server"
+ }
+
+ if serviceLabel != "" {
+ ConsistencyCounter.WithLabelValues("minlatency", source, serviceLabel).Inc()
+ }
+
+ databaseRev, err := ds.OptimizedRevision(ctx)
+ if err != nil {
+ return rewriteDatastoreError(ctx, err)
+ }
+ revision = databaseRev
+
+ case consistency.GetFullyConsistent():
+ // Fully Consistent: Use the datastore's synchronized revision.
+ if serviceLabel != "" {
+ ConsistencyCounter.WithLabelValues("full", "request", serviceLabel).Inc()
+ }
+
+ databaseRev, err := ds.HeadRevision(ctx)
+ if err != nil {
+ return rewriteDatastoreError(ctx, err)
+ }
+ revision = databaseRev
+
+ case consistency.GetAtLeastAsFresh() != nil:
+ // At least as fresh as: Pick one of the datastore's revision and that specified, which
+ // ever is later.
+ picked, pickedRequest, err := pickBestRevision(ctx, consistency.GetAtLeastAsFresh(), ds)
+ if err != nil {
+ return rewriteDatastoreError(ctx, err)
+ }
+
+ source := "server"
+ if pickedRequest {
+ source = "request"
+ }
+
+ if serviceLabel != "" {
+ ConsistencyCounter.WithLabelValues("atleast", source, serviceLabel).Inc()
+ }
+
+ revision = picked
+
+ case consistency.GetAtExactSnapshot() != nil:
+ // Exact snapshot: Use the revision as encoded in the zed token.
+ if serviceLabel != "" {
+ ConsistencyCounter.WithLabelValues("snapshot", "request", serviceLabel).Inc()
+ }
+
+ requestedRev, err := zedtoken.DecodeRevision(consistency.GetAtExactSnapshot(), ds)
+ if err != nil {
+ return errInvalidZedToken
+ }
+
+ err = ds.CheckRevision(ctx, requestedRev)
+ if err != nil {
+ return rewriteDatastoreError(ctx, err)
+ }
+
+ revision = requestedRev
+
+ default:
+ return fmt.Errorf("missing handling of consistency case in %v", consistency)
+ }
+
+ handle.(*revisionHandle).revision = revision
+ return nil
+}
+
+var bypassServiceWhitelist = map[string]struct{}{
+ "/grpc.reflection.v1alpha.ServerReflection/": {},
+ "/grpc.reflection.v1.ServerReflection/": {},
+ "/grpc.health.v1.Health/": {},
+}
+
+// UnaryServerInterceptor returns a new unary server interceptor that performs per-request exchange of
+// the specified consistency configuration for the revision at which to perform the request.
+func UnaryServerInterceptor(serviceLabel string) grpc.UnaryServerInterceptor {
+ return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+ for bypass := range bypassServiceWhitelist {
+ if strings.HasPrefix(info.FullMethod, bypass) {
+ return handler(ctx, req)
+ }
+ }
+ ds := datastoremw.MustFromContext(ctx)
+ newCtx := ContextWithHandle(ctx)
+ if err := AddRevisionToContext(newCtx, req, ds, serviceLabel); err != nil {
+ return nil, err
+ }
+
+ return handler(newCtx, req)
+ }
+}
+
+// StreamServerInterceptor returns a new stream server interceptor that performs per-request exchange of
+// the specified consistency configuration for the revision at which to perform the request.
+func StreamServerInterceptor(serviceLabel string) grpc.StreamServerInterceptor {
+ return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ for bypass := range bypassServiceWhitelist {
+ if strings.HasPrefix(info.FullMethod, bypass) {
+ return handler(srv, stream)
+ }
+ }
+ wrapper := &recvWrapper{stream, ContextWithHandle(stream.Context()), serviceLabel, AddRevisionToContext}
+ return handler(srv, wrapper)
+ }
+}
+
+type recvWrapper struct {
+ grpc.ServerStream
+ ctx context.Context
+ serviceLabel string
+ handler func(ctx context.Context, req interface{}, ds datastore.Datastore, serviceLabel string) error
+}
+
+func (s *recvWrapper) Context() context.Context { return s.ctx }
+
+func (s *recvWrapper) RecvMsg(m interface{}) error {
+ if err := s.ServerStream.RecvMsg(m); err != nil {
+ return err
+ }
+ ds := datastoremw.MustFromContext(s.ctx)
+ return s.handler(s.ctx, m, ds, s.serviceLabel)
+}
+
+// pickBestRevision compares the provided ZedToken with the optimized revision of the datastore, and returns the most
+// recent one. The boolean return value will be true if the provided ZedToken is the most recent, false otherwise.
+func pickBestRevision(ctx context.Context, requested *v1.ZedToken, ds datastore.Datastore) (datastore.Revision, bool, error) {
+ // Calculate a revision as we see fit
+ databaseRev, err := ds.OptimizedRevision(ctx)
+ if err != nil {
+ return datastore.NoRevision, false, err
+ }
+
+ if requested != nil {
+ requestedRev, err := zedtoken.DecodeRevision(requested, ds)
+ if err != nil {
+ return datastore.NoRevision, false, errInvalidZedToken
+ }
+
+ if databaseRev.GreaterThan(requestedRev) {
+ return databaseRev, false, nil
+ }
+
+ return requestedRev, true, nil
+ }
+
+ return databaseRev, false, nil
+}
+
+func rewriteDatastoreError(ctx context.Context, err error) error {
+ // Check if the error can be directly used.
+ if _, ok := status.FromError(err); ok {
+ return err
+ }
+
+ switch {
+ case errors.As(err, &datastore.InvalidRevisionError{}):
+ return status.Errorf(codes.OutOfRange, "invalid revision: %s", err)
+
+ case errors.As(err, &datastore.ReadOnlyError{}):
+ return shared.ErrServiceReadOnly
+
+ default:
+ log.Ctx(ctx).Err(err).Msg("unexpected consistency middleware error")
+ return err
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/doc.go b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/doc.go
new file mode 100644
index 0000000..593ec0a
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/doc.go
@@ -0,0 +1,2 @@
+// Package consistency defines middleware to set, based on the request's consistency level, the right datastore revision to use.
+package consistency
diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/forcefull.go b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/forcefull.go
new file mode 100644
index 0000000..0ec88f7
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/forcefull.go
@@ -0,0 +1,66 @@
+package consistency
+
+import (
+ "context"
+ "strings"
+
+ "google.golang.org/grpc"
+
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
+// ForceFullConsistencyUnaryServerInterceptor returns a new unary server interceptor that enforces full consistency
+// for all requests, except for those in the bypassServiceWhitelist.
+func ForceFullConsistencyUnaryServerInterceptor(serviceLabel string) grpc.UnaryServerInterceptor {
+ return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+ for bypass := range bypassServiceWhitelist {
+ if strings.HasPrefix(info.FullMethod, bypass) {
+ return handler(ctx, req)
+ }
+ }
+ ds := datastoremw.MustFromContext(ctx)
+ newCtx := ContextWithHandle(ctx)
+ if err := setFullConsistencyRevisionToContext(newCtx, req, ds, serviceLabel); err != nil {
+ return nil, err
+ }
+
+ return handler(newCtx, req)
+ }
+}
+
+// ForceFullConsistencyStreamServerInterceptor returns a new stream server interceptor that enforces full consistency
+// for all requests, except for those in the bypassServiceWhitelist.
+func ForceFullConsistencyStreamServerInterceptor(serviceLabel string) grpc.StreamServerInterceptor {
+ return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ for bypass := range bypassServiceWhitelist {
+ if strings.HasPrefix(info.FullMethod, bypass) {
+ return handler(srv, stream)
+ }
+ }
+ wrapper := &recvWrapper{stream, ContextWithHandle(stream.Context()), serviceLabel, setFullConsistencyRevisionToContext}
+ return handler(srv, wrapper)
+ }
+}
+
+func setFullConsistencyRevisionToContext(ctx context.Context, req interface{}, ds datastore.Datastore, serviceLabel string) error {
+ handle := ctx.Value(revisionKey)
+ if handle == nil {
+ return nil
+ }
+
+ switch req.(type) {
+ case hasConsistency:
+ if serviceLabel != "" {
+ ConsistencyCounter.WithLabelValues("full", "request", serviceLabel).Inc()
+ }
+
+ databaseRev, err := ds.HeadRevision(ctx)
+ if err != nil {
+ return rewriteDatastoreError(ctx, err)
+ }
+ handle.(*revisionHandle).revision = databaseRev
+ }
+
+ return nil
+}