summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
committermo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
commit20ef0d92694465ac86b550df139e8366a0a2b4fa (patch)
tree3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go
parent44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff)
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go')
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go162
1 files changed, 162 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go b/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go
new file mode 100644
index 0000000..7a0904a
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go
@@ -0,0 +1,162 @@
+package requestid
+
+import (
+ "context"
+
+ log "github.com/authzed/spicedb/internal/logging"
+
+ "github.com/authzed/authzed-go/pkg/requestmeta"
+ "github.com/authzed/authzed-go/pkg/responsemeta"
+ "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
+ "github.com/rs/xid"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+)
+
+const metadataKey = string(requestmeta.RequestIDKey)
+
+// Option instances control how the middleware is initialized.
+type Option func(*handleRequestID)
+
+// GenerateIfMissing will instruct the middleware to create a request ID if one
+// isn't already on the incoming request.
+//
+// default: false
+func GenerateIfMissing(enable bool) Option {
+ return func(reporter *handleRequestID) {
+ reporter.generateIfMissing = enable
+ }
+}
+
+// IDGenerator functions are used to generate request IDs if a new one is needed.
+type IDGenerator func() string
+
+// GenerateRequestID generates a new request ID.
+func GenerateRequestID() string {
+ return xid.New().String()
+}
+
+type handleRequestID struct {
+ generateIfMissing bool
+ requestIDGenerator IDGenerator
+}
+
+func (r *handleRequestID) ClientReporter(ctx context.Context, meta interceptors.CallMeta) (interceptors.Reporter, context.Context) {
+ haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx)
+
+ if haveRequestID {
+ ctx = requestmeta.SetRequestHeaders(ctx, map[requestmeta.RequestMetadataHeaderKey]string{
+ requestmeta.RequestIDKey: requestID,
+ })
+ }
+
+ return interceptors.NoopReporter{}, ctx
+}
+
+func (r *handleRequestID) ServerReporter(ctx context.Context, _ interceptors.CallMeta) (interceptors.Reporter, context.Context) {
+ haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx)
+
+ if haveRequestID {
+ err := responsemeta.SetResponseHeaderMetadata(ctx, map[responsemeta.ResponseMetadataHeaderKey]string{
+ responsemeta.RequestID: requestID,
+ })
+ // if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite
+ // this prevents logging unnecessary error messages
+ if ctx.Err() != nil {
+ return interceptors.NoopReporter{}, ctx
+ }
+ if err != nil {
+ log.Ctx(ctx).Warn().Err(err).Msg("requestid: could not report metadata")
+ }
+ }
+
+ return interceptors.NoopReporter{}, ctx
+}
+
+func (r *handleRequestID) fromContextOrGenerate(ctx context.Context) (bool, string, context.Context) {
+ haveRequestID, requestID, md := fromContext(ctx)
+
+ if !haveRequestID && r.generateIfMissing {
+ requestID = r.requestIDGenerator()
+ haveRequestID = true
+
+ // Inject the newly generated request ID into the metadata
+ if md == nil {
+ md = metadata.New(nil)
+ }
+
+ md.Set(metadataKey, requestID)
+ ctx = metadata.NewIncomingContext(ctx, md)
+ }
+
+ return haveRequestID, requestID, ctx
+}
+
+func fromContext(ctx context.Context) (bool, string, metadata.MD) {
+ var requestID string
+ var haveRequestID bool
+ md, ok := metadata.FromIncomingContext(ctx)
+ if ok {
+ var requestIDs []string
+ requestIDs, haveRequestID = md[metadataKey]
+ if haveRequestID {
+ requestID = requestIDs[0]
+ }
+ }
+
+ return haveRequestID, requestID, md
+}
+
+// PropagateIfExists copies the request ID from the source context to the target context if it exists.
+// The updated target context is returned.
+func PropagateIfExists(source, target context.Context) context.Context {
+ exists, requestID, _ := fromContext(source)
+
+ if exists {
+ targetMD, _ := metadata.FromIncomingContext(target)
+ if targetMD == nil {
+ targetMD = metadata.New(nil)
+ }
+
+ targetMD.Set(metadataKey, requestID)
+ return metadata.NewIncomingContext(target, targetMD)
+ }
+
+ return target
+}
+
+// UnaryServerInterceptor returns a new interceptor which handles server request IDs according
+// to the provided options.
+func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
+ return interceptors.UnaryServerInterceptor(createReporter(opts))
+}
+
+// StreamServerInterceptor returns a new interceptor which handles server request IDs according
+// to the provided options.
+func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
+ return interceptors.StreamServerInterceptor(createReporter(opts))
+}
+
+// UnaryClientInterceptor returns a new interceptor which handles client request IDs according
+// to the provided options.
+func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
+ return interceptors.UnaryClientInterceptor(createReporter(opts))
+}
+
+// StreamClientInterceptor returns a new interceptor which handles client requestIDs according
+// to the provided options.
+func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor {
+ return interceptors.StreamClientInterceptor(createReporter(opts))
+}
+
+func createReporter(opts []Option) *handleRequestID {
+ reporter := &handleRequestID{
+ requestIDGenerator: GenerateRequestID,
+ }
+
+ for _, opt := range opts {
+ opt(reporter)
+ }
+
+ return reporter
+}