diff options
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/middleware')
7 files changed, 616 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/consistency.go b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/consistency.go new file mode 100644 index 0000000..8f75a19 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/consistency.go @@ -0,0 +1,280 @@ +package consistency + +import ( + "context" + "errors" + "fmt" + "strings" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + log "github.com/authzed/spicedb/internal/logging" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/services/shared" + "github.com/authzed/spicedb/pkg/cursor" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/zedtoken" +) + +var ConsistencyCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "middleware", + Name: "consistency_assigned_total", + Help: "Count of the consistencies used per request", +}, []string{"method", "source", "service"}) + +type hasConsistency interface{ GetConsistency() *v1.Consistency } + +type hasOptionalCursor interface{ GetOptionalCursor() *v1.Cursor } + +type ctxKeyType struct{} + +var revisionKey ctxKeyType = struct{}{} + +var errInvalidZedToken = errors.New("invalid revision requested") + +type revisionHandle struct { + revision datastore.Revision +} + +// ContextWithHandle adds a placeholder to a context that will later be +// filled by the revision +func ContextWithHandle(ctx context.Context) context.Context { + return context.WithValue(ctx, revisionKey, &revisionHandle{}) +} + +// RevisionFromContext reads the selected revision out of a context.Context, computes a zedtoken +// from it, and returns an error if it has not been set on the context. +func RevisionFromContext(ctx context.Context) (datastore.Revision, *v1.ZedToken, error) { + if c := ctx.Value(revisionKey); c != nil { + handle := c.(*revisionHandle) + rev := handle.revision + if rev != nil { + return rev, zedtoken.MustNewFromRevision(rev), nil + } + } + + return nil, nil, fmt.Errorf("consistency middleware did not inject revision") +} + +// AddRevisionToContext adds a revision to the given context, based on the consistency block found +// in the given request (if applicable). +func AddRevisionToContext(ctx context.Context, req interface{}, ds datastore.Datastore, serviceLabel string) error { + switch req := req.(type) { + case hasConsistency: + return addRevisionToContextFromConsistency(ctx, req, ds, serviceLabel) + default: + return nil + } +} + +// addRevisionToContextFromConsistency adds a revision to the given context, based on the consistency block found +// in the given request (if applicable). +func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency, ds datastore.Datastore, serviceLabel string) error { + handle := ctx.Value(revisionKey) + if handle == nil { + return nil + } + + var revision datastore.Revision + consistency := req.GetConsistency() + + withOptionalCursor, hasOptionalCursor := req.(hasOptionalCursor) + + switch { + case hasOptionalCursor && withOptionalCursor.GetOptionalCursor() != nil: + // Always use the revision encoded in the cursor. + if serviceLabel != "" { + ConsistencyCounter.WithLabelValues("snapshot", "cursor", serviceLabel).Inc() + } + + requestedRev, err := cursor.DecodeToDispatchRevision(withOptionalCursor.GetOptionalCursor(), ds) + if err != nil { + return rewriteDatastoreError(ctx, err) + } + + err = ds.CheckRevision(ctx, requestedRev) + if err != nil { + return rewriteDatastoreError(ctx, err) + } + + revision = requestedRev + + case consistency == nil || consistency.GetMinimizeLatency(): + // Minimize Latency: Use the datastore's current revision, whatever it may be. + source := "request" + if consistency == nil { + source = "server" + } + + if serviceLabel != "" { + ConsistencyCounter.WithLabelValues("minlatency", source, serviceLabel).Inc() + } + + databaseRev, err := ds.OptimizedRevision(ctx) + if err != nil { + return rewriteDatastoreError(ctx, err) + } + revision = databaseRev + + case consistency.GetFullyConsistent(): + // Fully Consistent: Use the datastore's synchronized revision. + if serviceLabel != "" { + ConsistencyCounter.WithLabelValues("full", "request", serviceLabel).Inc() + } + + databaseRev, err := ds.HeadRevision(ctx) + if err != nil { + return rewriteDatastoreError(ctx, err) + } + revision = databaseRev + + case consistency.GetAtLeastAsFresh() != nil: + // At least as fresh as: Pick one of the datastore's revision and that specified, which + // ever is later. + picked, pickedRequest, err := pickBestRevision(ctx, consistency.GetAtLeastAsFresh(), ds) + if err != nil { + return rewriteDatastoreError(ctx, err) + } + + source := "server" + if pickedRequest { + source = "request" + } + + if serviceLabel != "" { + ConsistencyCounter.WithLabelValues("atleast", source, serviceLabel).Inc() + } + + revision = picked + + case consistency.GetAtExactSnapshot() != nil: + // Exact snapshot: Use the revision as encoded in the zed token. + if serviceLabel != "" { + ConsistencyCounter.WithLabelValues("snapshot", "request", serviceLabel).Inc() + } + + requestedRev, err := zedtoken.DecodeRevision(consistency.GetAtExactSnapshot(), ds) + if err != nil { + return errInvalidZedToken + } + + err = ds.CheckRevision(ctx, requestedRev) + if err != nil { + return rewriteDatastoreError(ctx, err) + } + + revision = requestedRev + + default: + return fmt.Errorf("missing handling of consistency case in %v", consistency) + } + + handle.(*revisionHandle).revision = revision + return nil +} + +var bypassServiceWhitelist = map[string]struct{}{ + "/grpc.reflection.v1alpha.ServerReflection/": {}, + "/grpc.reflection.v1.ServerReflection/": {}, + "/grpc.health.v1.Health/": {}, +} + +// UnaryServerInterceptor returns a new unary server interceptor that performs per-request exchange of +// the specified consistency configuration for the revision at which to perform the request. +func UnaryServerInterceptor(serviceLabel string) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + for bypass := range bypassServiceWhitelist { + if strings.HasPrefix(info.FullMethod, bypass) { + return handler(ctx, req) + } + } + ds := datastoremw.MustFromContext(ctx) + newCtx := ContextWithHandle(ctx) + if err := AddRevisionToContext(newCtx, req, ds, serviceLabel); err != nil { + return nil, err + } + + return handler(newCtx, req) + } +} + +// StreamServerInterceptor returns a new stream server interceptor that performs per-request exchange of +// the specified consistency configuration for the revision at which to perform the request. +func StreamServerInterceptor(serviceLabel string) grpc.StreamServerInterceptor { + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + for bypass := range bypassServiceWhitelist { + if strings.HasPrefix(info.FullMethod, bypass) { + return handler(srv, stream) + } + } + wrapper := &recvWrapper{stream, ContextWithHandle(stream.Context()), serviceLabel, AddRevisionToContext} + return handler(srv, wrapper) + } +} + +type recvWrapper struct { + grpc.ServerStream + ctx context.Context + serviceLabel string + handler func(ctx context.Context, req interface{}, ds datastore.Datastore, serviceLabel string) error +} + +func (s *recvWrapper) Context() context.Context { return s.ctx } + +func (s *recvWrapper) RecvMsg(m interface{}) error { + if err := s.ServerStream.RecvMsg(m); err != nil { + return err + } + ds := datastoremw.MustFromContext(s.ctx) + return s.handler(s.ctx, m, ds, s.serviceLabel) +} + +// pickBestRevision compares the provided ZedToken with the optimized revision of the datastore, and returns the most +// recent one. The boolean return value will be true if the provided ZedToken is the most recent, false otherwise. +func pickBestRevision(ctx context.Context, requested *v1.ZedToken, ds datastore.Datastore) (datastore.Revision, bool, error) { + // Calculate a revision as we see fit + databaseRev, err := ds.OptimizedRevision(ctx) + if err != nil { + return datastore.NoRevision, false, err + } + + if requested != nil { + requestedRev, err := zedtoken.DecodeRevision(requested, ds) + if err != nil { + return datastore.NoRevision, false, errInvalidZedToken + } + + if databaseRev.GreaterThan(requestedRev) { + return databaseRev, false, nil + } + + return requestedRev, true, nil + } + + return databaseRev, false, nil +} + +func rewriteDatastoreError(ctx context.Context, err error) error { + // Check if the error can be directly used. + if _, ok := status.FromError(err); ok { + return err + } + + switch { + case errors.As(err, &datastore.InvalidRevisionError{}): + return status.Errorf(codes.OutOfRange, "invalid revision: %s", err) + + case errors.As(err, &datastore.ReadOnlyError{}): + return shared.ErrServiceReadOnly + + default: + log.Ctx(ctx).Err(err).Msg("unexpected consistency middleware error") + return err + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/doc.go b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/doc.go new file mode 100644 index 0000000..593ec0a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/doc.go @@ -0,0 +1,2 @@ +// Package consistency defines middleware to set, based on the request's consistency level, the right datastore revision to use. +package consistency diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/forcefull.go b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/forcefull.go new file mode 100644 index 0000000..0ec88f7 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/middleware/consistency/forcefull.go @@ -0,0 +1,66 @@ +package consistency + +import ( + "context" + "strings" + + "google.golang.org/grpc" + + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/pkg/datastore" +) + +// ForceFullConsistencyUnaryServerInterceptor returns a new unary server interceptor that enforces full consistency +// for all requests, except for those in the bypassServiceWhitelist. +func ForceFullConsistencyUnaryServerInterceptor(serviceLabel string) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + for bypass := range bypassServiceWhitelist { + if strings.HasPrefix(info.FullMethod, bypass) { + return handler(ctx, req) + } + } + ds := datastoremw.MustFromContext(ctx) + newCtx := ContextWithHandle(ctx) + if err := setFullConsistencyRevisionToContext(newCtx, req, ds, serviceLabel); err != nil { + return nil, err + } + + return handler(newCtx, req) + } +} + +// ForceFullConsistencyStreamServerInterceptor returns a new stream server interceptor that enforces full consistency +// for all requests, except for those in the bypassServiceWhitelist. +func ForceFullConsistencyStreamServerInterceptor(serviceLabel string) grpc.StreamServerInterceptor { + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + for bypass := range bypassServiceWhitelist { + if strings.HasPrefix(info.FullMethod, bypass) { + return handler(srv, stream) + } + } + wrapper := &recvWrapper{stream, ContextWithHandle(stream.Context()), serviceLabel, setFullConsistencyRevisionToContext} + return handler(srv, wrapper) + } +} + +func setFullConsistencyRevisionToContext(ctx context.Context, req interface{}, ds datastore.Datastore, serviceLabel string) error { + handle := ctx.Value(revisionKey) + if handle == nil { + return nil + } + + switch req.(type) { + case hasConsistency: + if serviceLabel != "" { + ConsistencyCounter.WithLabelValues("full", "request", serviceLabel).Inc() + } + + databaseRev, err := ds.HeadRevision(ctx) + if err != nil { + return rewriteDatastoreError(ctx, err) + } + handle.(*revisionHandle).revision = databaseRev + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/nodeid/doc.go b/vendor/github.com/authzed/spicedb/pkg/middleware/nodeid/doc.go new file mode 100644 index 0000000..571288c --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/middleware/nodeid/doc.go @@ -0,0 +1,2 @@ +// Package nodeid defines middleware to update the context with the Id of the SpiceDB node running the request. +package nodeid diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/nodeid/nodeid.go b/vendor/github.com/authzed/spicedb/pkg/middleware/nodeid/nodeid.go new file mode 100644 index 0000000..3885036 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/middleware/nodeid/nodeid.go @@ -0,0 +1,102 @@ +package nodeid + +import ( + "context" + "fmt" + "os" + + "github.com/cespare/xxhash/v2" + middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" + "github.com/rs/zerolog/log" + "google.golang.org/grpc" +) + +const spiceDBPrefix = "spicedb:" + +type ctxKeyType struct{} + +var nodeIDKey ctxKeyType = struct{}{} + +type nodeIDHandle struct { + nodeID string +} + +var defaultNodeID string + +func init() { + hostname, err := os.Hostname() + if err != nil { + log.Warn().Err(err).Msg("failed to get hostname, using an empty node ID") + return + } + + // Hash the hostname to get the final default node ID. + hasher := xxhash.New() + if _, err := hasher.WriteString(hostname); err != nil { + log.Warn().Err(err).Msg("failed to hash hostname, using an empty node ID") + return + } + + defaultNodeID = spiceDBPrefix + fmt.Sprintf("%x", hasher.Sum(nil)) +} + +// ContextWithHandle adds a placeholder to a context that will later be +// filled by the Node ID. +func ContextWithHandle(ctx context.Context) context.Context { + return context.WithValue(ctx, nodeIDKey, &nodeIDHandle{}) +} + +// FromContext reads the node's ID out of a context.Context. +func FromContext(ctx context.Context) (string, error) { + if c := ctx.Value(nodeIDKey); c != nil { + handle := c.(*nodeIDHandle) + if handle.nodeID != "" { + return handle.nodeID, nil + } + } + + if err := setInContext(ctx, defaultNodeID); err != nil { + return "", err + } + + return defaultNodeID, nil +} + +// setInContext adds a node ID to the given context +func setInContext(ctx context.Context, nodeID string) error { + handle := ctx.Value(nodeIDKey) + if handle == nil { + return nil + } + handle.(*nodeIDHandle).nodeID = nodeID + return nil +} + +// UnaryServerInterceptor returns a new unary server interceptor that adds the +// node ID to the context. If empty, spicedb:$hostname is used. +func UnaryServerInterceptor(nodeID string) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + newCtx := ContextWithHandle(ctx) + if nodeID != "" { + if err := setInContext(newCtx, nodeID); err != nil { + return nil, err + } + } + return handler(newCtx, req) + } +} + +// StreamServerInterceptor returns a new stream server interceptor that adds the +// node ID to the context. If empty, spicedb:$hostname is used. +func StreamServerInterceptor(nodeID string) grpc.StreamServerInterceptor { + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + wrapped := middleware.WrapServerStream(stream) + wrapped.WrappedContext = ContextWithHandle(wrapped.WrappedContext) + if nodeID != "" { + if err := setInContext(wrapped.WrappedContext, nodeID); err != nil { + return err + } + } + return handler(srv, wrapped) + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/doc.go b/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/doc.go new file mode 100644 index 0000000..2f47d2f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/doc.go @@ -0,0 +1,2 @@ +// Package requestid defines middleware to set a request or response header with a request ID. +package requestid diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go b/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go new file mode 100644 index 0000000..7a0904a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go @@ -0,0 +1,162 @@ +package requestid + +import ( + "context" + + log "github.com/authzed/spicedb/internal/logging" + + "github.com/authzed/authzed-go/pkg/requestmeta" + "github.com/authzed/authzed-go/pkg/responsemeta" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" + "github.com/rs/xid" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +const metadataKey = string(requestmeta.RequestIDKey) + +// Option instances control how the middleware is initialized. +type Option func(*handleRequestID) + +// GenerateIfMissing will instruct the middleware to create a request ID if one +// isn't already on the incoming request. +// +// default: false +func GenerateIfMissing(enable bool) Option { + return func(reporter *handleRequestID) { + reporter.generateIfMissing = enable + } +} + +// IDGenerator functions are used to generate request IDs if a new one is needed. +type IDGenerator func() string + +// GenerateRequestID generates a new request ID. +func GenerateRequestID() string { + return xid.New().String() +} + +type handleRequestID struct { + generateIfMissing bool + requestIDGenerator IDGenerator +} + +func (r *handleRequestID) ClientReporter(ctx context.Context, meta interceptors.CallMeta) (interceptors.Reporter, context.Context) { + haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx) + + if haveRequestID { + ctx = requestmeta.SetRequestHeaders(ctx, map[requestmeta.RequestMetadataHeaderKey]string{ + requestmeta.RequestIDKey: requestID, + }) + } + + return interceptors.NoopReporter{}, ctx +} + +func (r *handleRequestID) ServerReporter(ctx context.Context, _ interceptors.CallMeta) (interceptors.Reporter, context.Context) { + haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx) + + if haveRequestID { + err := responsemeta.SetResponseHeaderMetadata(ctx, map[responsemeta.ResponseMetadataHeaderKey]string{ + responsemeta.RequestID: requestID, + }) + // if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite + // this prevents logging unnecessary error messages + if ctx.Err() != nil { + return interceptors.NoopReporter{}, ctx + } + if err != nil { + log.Ctx(ctx).Warn().Err(err).Msg("requestid: could not report metadata") + } + } + + return interceptors.NoopReporter{}, ctx +} + +func (r *handleRequestID) fromContextOrGenerate(ctx context.Context) (bool, string, context.Context) { + haveRequestID, requestID, md := fromContext(ctx) + + if !haveRequestID && r.generateIfMissing { + requestID = r.requestIDGenerator() + haveRequestID = true + + // Inject the newly generated request ID into the metadata + if md == nil { + md = metadata.New(nil) + } + + md.Set(metadataKey, requestID) + ctx = metadata.NewIncomingContext(ctx, md) + } + + return haveRequestID, requestID, ctx +} + +func fromContext(ctx context.Context) (bool, string, metadata.MD) { + var requestID string + var haveRequestID bool + md, ok := metadata.FromIncomingContext(ctx) + if ok { + var requestIDs []string + requestIDs, haveRequestID = md[metadataKey] + if haveRequestID { + requestID = requestIDs[0] + } + } + + return haveRequestID, requestID, md +} + +// PropagateIfExists copies the request ID from the source context to the target context if it exists. +// The updated target context is returned. +func PropagateIfExists(source, target context.Context) context.Context { + exists, requestID, _ := fromContext(source) + + if exists { + targetMD, _ := metadata.FromIncomingContext(target) + if targetMD == nil { + targetMD = metadata.New(nil) + } + + targetMD.Set(metadataKey, requestID) + return metadata.NewIncomingContext(target, targetMD) + } + + return target +} + +// UnaryServerInterceptor returns a new interceptor which handles server request IDs according +// to the provided options. +func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { + return interceptors.UnaryServerInterceptor(createReporter(opts)) +} + +// StreamServerInterceptor returns a new interceptor which handles server request IDs according +// to the provided options. +func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { + return interceptors.StreamServerInterceptor(createReporter(opts)) +} + +// UnaryClientInterceptor returns a new interceptor which handles client request IDs according +// to the provided options. +func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { + return interceptors.UnaryClientInterceptor(createReporter(opts)) +} + +// StreamClientInterceptor returns a new interceptor which handles client requestIDs according +// to the provided options. +func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor { + return interceptors.StreamClientInterceptor(createReporter(opts)) +} + +func createReporter(opts []Option) *handleRequestID { + reporter := &handleRequestID{ + requestIDGenerator: GenerateRequestID, + } + + for _, opt := range opts { + opt(reporter) + } + + return reporter +} |
