diff options
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go')
| -rw-r--r-- | vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go | 162 |
1 files changed, 162 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go b/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go new file mode 100644 index 0000000..7a0904a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go @@ -0,0 +1,162 @@ +package requestid + +import ( + "context" + + log "github.com/authzed/spicedb/internal/logging" + + "github.com/authzed/authzed-go/pkg/requestmeta" + "github.com/authzed/authzed-go/pkg/responsemeta" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" + "github.com/rs/xid" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +const metadataKey = string(requestmeta.RequestIDKey) + +// Option instances control how the middleware is initialized. +type Option func(*handleRequestID) + +// GenerateIfMissing will instruct the middleware to create a request ID if one +// isn't already on the incoming request. +// +// default: false +func GenerateIfMissing(enable bool) Option { + return func(reporter *handleRequestID) { + reporter.generateIfMissing = enable + } +} + +// IDGenerator functions are used to generate request IDs if a new one is needed. +type IDGenerator func() string + +// GenerateRequestID generates a new request ID. +func GenerateRequestID() string { + return xid.New().String() +} + +type handleRequestID struct { + generateIfMissing bool + requestIDGenerator IDGenerator +} + +func (r *handleRequestID) ClientReporter(ctx context.Context, meta interceptors.CallMeta) (interceptors.Reporter, context.Context) { + haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx) + + if haveRequestID { + ctx = requestmeta.SetRequestHeaders(ctx, map[requestmeta.RequestMetadataHeaderKey]string{ + requestmeta.RequestIDKey: requestID, + }) + } + + return interceptors.NoopReporter{}, ctx +} + +func (r *handleRequestID) ServerReporter(ctx context.Context, _ interceptors.CallMeta) (interceptors.Reporter, context.Context) { + haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx) + + if haveRequestID { + err := responsemeta.SetResponseHeaderMetadata(ctx, map[responsemeta.ResponseMetadataHeaderKey]string{ + responsemeta.RequestID: requestID, + }) + // if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite + // this prevents logging unnecessary error messages + if ctx.Err() != nil { + return interceptors.NoopReporter{}, ctx + } + if err != nil { + log.Ctx(ctx).Warn().Err(err).Msg("requestid: could not report metadata") + } + } + + return interceptors.NoopReporter{}, ctx +} + +func (r *handleRequestID) fromContextOrGenerate(ctx context.Context) (bool, string, context.Context) { + haveRequestID, requestID, md := fromContext(ctx) + + if !haveRequestID && r.generateIfMissing { + requestID = r.requestIDGenerator() + haveRequestID = true + + // Inject the newly generated request ID into the metadata + if md == nil { + md = metadata.New(nil) + } + + md.Set(metadataKey, requestID) + ctx = metadata.NewIncomingContext(ctx, md) + } + + return haveRequestID, requestID, ctx +} + +func fromContext(ctx context.Context) (bool, string, metadata.MD) { + var requestID string + var haveRequestID bool + md, ok := metadata.FromIncomingContext(ctx) + if ok { + var requestIDs []string + requestIDs, haveRequestID = md[metadataKey] + if haveRequestID { + requestID = requestIDs[0] + } + } + + return haveRequestID, requestID, md +} + +// PropagateIfExists copies the request ID from the source context to the target context if it exists. +// The updated target context is returned. +func PropagateIfExists(source, target context.Context) context.Context { + exists, requestID, _ := fromContext(source) + + if exists { + targetMD, _ := metadata.FromIncomingContext(target) + if targetMD == nil { + targetMD = metadata.New(nil) + } + + targetMD.Set(metadataKey, requestID) + return metadata.NewIncomingContext(target, targetMD) + } + + return target +} + +// UnaryServerInterceptor returns a new interceptor which handles server request IDs according +// to the provided options. +func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { + return interceptors.UnaryServerInterceptor(createReporter(opts)) +} + +// StreamServerInterceptor returns a new interceptor which handles server request IDs according +// to the provided options. +func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { + return interceptors.StreamServerInterceptor(createReporter(opts)) +} + +// UnaryClientInterceptor returns a new interceptor which handles client request IDs according +// to the provided options. +func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { + return interceptors.UnaryClientInterceptor(createReporter(opts)) +} + +// StreamClientInterceptor returns a new interceptor which handles client requestIDs according +// to the provided options. +func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor { + return interceptors.StreamClientInterceptor(createReporter(opts)) +} + +func createReporter(opts []Option) *handleRequestID { + reporter := &handleRequestID{ + requestIDGenerator: GenerateRequestID, + } + + for _, opt := range opts { + opt(reporter) + } + + return reporter +} |
