summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/pkg/middleware/requestid/requestid.go
blob: 7a0904ab9fb1c3ceeee74f8025109f2faedf5202 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
package requestid

import (
	"context"

	log "github.com/authzed/spicedb/internal/logging"

	"github.com/authzed/authzed-go/pkg/requestmeta"
	"github.com/authzed/authzed-go/pkg/responsemeta"
	"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
	"github.com/rs/xid"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

const metadataKey = string(requestmeta.RequestIDKey)

// Option instances control how the middleware is initialized.
type Option func(*handleRequestID)

// GenerateIfMissing will instruct the middleware to create a request ID if one
// isn't already on the incoming request.
//
// default: false
func GenerateIfMissing(enable bool) Option {
	return func(reporter *handleRequestID) {
		reporter.generateIfMissing = enable
	}
}

// IDGenerator functions are used to generate request IDs if a new one is needed.
type IDGenerator func() string

// GenerateRequestID generates a new request ID.
func GenerateRequestID() string {
	return xid.New().String()
}

type handleRequestID struct {
	generateIfMissing  bool
	requestIDGenerator IDGenerator
}

func (r *handleRequestID) ClientReporter(ctx context.Context, meta interceptors.CallMeta) (interceptors.Reporter, context.Context) {
	haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx)

	if haveRequestID {
		ctx = requestmeta.SetRequestHeaders(ctx, map[requestmeta.RequestMetadataHeaderKey]string{
			requestmeta.RequestIDKey: requestID,
		})
	}

	return interceptors.NoopReporter{}, ctx
}

func (r *handleRequestID) ServerReporter(ctx context.Context, _ interceptors.CallMeta) (interceptors.Reporter, context.Context) {
	haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx)

	if haveRequestID {
		err := responsemeta.SetResponseHeaderMetadata(ctx, map[responsemeta.ResponseMetadataHeaderKey]string{
			responsemeta.RequestID: requestID,
		})
		// if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite
		// this prevents logging unnecessary error messages
		if ctx.Err() != nil {
			return interceptors.NoopReporter{}, ctx
		}
		if err != nil {
			log.Ctx(ctx).Warn().Err(err).Msg("requestid: could not report metadata")
		}
	}

	return interceptors.NoopReporter{}, ctx
}

func (r *handleRequestID) fromContextOrGenerate(ctx context.Context) (bool, string, context.Context) {
	haveRequestID, requestID, md := fromContext(ctx)

	if !haveRequestID && r.generateIfMissing {
		requestID = r.requestIDGenerator()
		haveRequestID = true

		// Inject the newly generated request ID into the metadata
		if md == nil {
			md = metadata.New(nil)
		}

		md.Set(metadataKey, requestID)
		ctx = metadata.NewIncomingContext(ctx, md)
	}

	return haveRequestID, requestID, ctx
}

func fromContext(ctx context.Context) (bool, string, metadata.MD) {
	var requestID string
	var haveRequestID bool
	md, ok := metadata.FromIncomingContext(ctx)
	if ok {
		var requestIDs []string
		requestIDs, haveRequestID = md[metadataKey]
		if haveRequestID {
			requestID = requestIDs[0]
		}
	}

	return haveRequestID, requestID, md
}

// PropagateIfExists copies the request ID from the source context to the target context if it exists.
// The updated target context is returned.
func PropagateIfExists(source, target context.Context) context.Context {
	exists, requestID, _ := fromContext(source)

	if exists {
		targetMD, _ := metadata.FromIncomingContext(target)
		if targetMD == nil {
			targetMD = metadata.New(nil)
		}

		targetMD.Set(metadataKey, requestID)
		return metadata.NewIncomingContext(target, targetMD)
	}

	return target
}

// UnaryServerInterceptor returns a new interceptor which handles server request IDs according
// to the provided options.
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
	return interceptors.UnaryServerInterceptor(createReporter(opts))
}

// StreamServerInterceptor returns a new interceptor which handles server request IDs according
// to the provided options.
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
	return interceptors.StreamServerInterceptor(createReporter(opts))
}

// UnaryClientInterceptor returns a new interceptor which handles client request IDs according
// to the provided options.
func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
	return interceptors.UnaryClientInterceptor(createReporter(opts))
}

// StreamClientInterceptor returns a new interceptor which handles client requestIDs according
// to the provided options.
func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor {
	return interceptors.StreamClientInterceptor(createReporter(opts))
}

func createReporter(opts []Option) *handleRequestID {
	reporter := &handleRequestID{
		requestIDGenerator: GenerateRequestID,
	}

	for _, opt := range opts {
		opt(reporter)
	}

	return reporter
}