//! `ServerReflection` service implementation for the `pb::v1` reflection protocol types.
use std::{fmt, sync::Arc};
use tokio::sync::mpsc;
use tokio_stream::{Stream, StreamExt};
use tonic::{Request, Response, Status, Streaming};
use super::ReflectionServiceState;
use crate::pb::v1::server_reflection_request::MessageRequest;
use crate::pb::v1::server_reflection_response::MessageResponse;
pub use crate::pb::v1::server_reflection_server::{ServerReflection, ServerReflectionServer};
use crate::pb::v1::{
ExtensionNumberResponse, FileDescriptorResponse, ListServiceResponse, ServerReflectionRequest,
ServerReflectionResponse, ServiceResponse,
};
/// An implementation for `ServerReflection`.
///
/// The service owns its lookup state behind an [`Arc`]; each call to
/// `server_reflection_info` clones that handle into the task it spawns, so the
/// state is shared rather than copied.
#[derive(Debug)]
pub struct ReflectionService {
/// Shared reflection state queried for files, symbols and service names.
state: Arc<ReflectionServiceState>,
}
#[tonic::async_trait]
impl ServerReflection for ReflectionService {
    type ServerReflectionInfoStream = ServerReflectionInfoStream;

    /// Serves the bidirectional reflection stream.
    ///
    /// A background task is spawned that reads requests from the inbound
    /// stream one at a time and pushes one reply per request into a channel
    /// of capacity 1; the receiving half of that channel is returned to the
    /// caller as the response stream.
    ///
    /// The task ends when:
    /// * the inbound stream is exhausted or yields an error,
    /// * a request cannot be answered — the resulting `Status` is forwarded
    ///   to the client as the final stream item, or
    /// * the response stream has been dropped, so nobody is left to receive
    ///   replies.
    ///
    /// # Errors
    ///
    /// The call itself always returns `Ok`; per-request failures are reported
    /// in-band as `Err(Status)` items on the returned stream.
    async fn server_reflection_info(
        &self,
        req: Request<Streaming<ServerReflectionRequest>>,
    ) -> Result<Response<Self::ServerReflectionInfoStream>, Status> {
        let mut req_rx = req.into_inner();
        let (resp_tx, resp_rx) = mpsc::channel::<Result<ServerReflectionResponse, Status>>(1);
        let state = self.state.clone();

        tokio::spawn(async move {
            while let Some(req) = req_rx.next().await {
                // An error on the inbound stream terminates the exchange.
                let Ok(req) = req else {
                    return;
                };

                // Borrow the request payload: every arm produces an owned
                // result, so `req` can be moved into the response afterwards
                // without cloning it.
                let resp_msg = match &req.message_request {
                    None => Err(Status::invalid_argument("invalid MessageRequest")),
                    Some(MessageRequest::FileByFilename(s)) => {
                        state.file_by_filename(s).map(|fd| {
                            MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
                                file_descriptor_proto: vec![fd],
                            })
                        })
                    }
                    Some(MessageRequest::FileContainingSymbol(s)) => {
                        state.symbol_by_name(s).map(|fd| {
                            MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
                                file_descriptor_proto: vec![fd],
                            })
                        })
                    }
                    Some(MessageRequest::FileContainingExtension(_)) => {
                        Err(Status::not_found("extensions are not supported"))
                    }
                    Some(MessageRequest::AllExtensionNumbersOfType(_)) => {
                        // NOTE: Workaround. Some grpc clients (e.g. grpcurl) expect this method not to fail.
                        // https://github.com/hyperium/tonic/issues/1077
                        Ok(MessageResponse::AllExtensionNumbersResponse(
                            ExtensionNumberResponse::default(),
                        ))
                    }
                    Some(MessageRequest::ListServices(_)) => {
                        Ok(MessageResponse::ListServicesResponse(ListServiceResponse {
                            service: state
                                .list_services()
                                .iter()
                                .map(|s| ServiceResponse { name: s.clone() })
                                .collect(),
                        }))
                    }
                };

                match resp_msg {
                    Ok(resp_msg) => {
                        let resp = ServerReflectionResponse {
                            // `host` is cloned first; `req` is then moved whole.
                            valid_host: req.host.clone(),
                            original_request: Some(req),
                            message_response: Some(resp_msg),
                        };
                        // `send` fails only when the receiving half is gone,
                        // i.e. the client dropped the response stream. That is
                        // a normal disconnect, not a bug, so stop quietly
                        // instead of panicking inside the task.
                        if resp_tx.send(Ok(resp)).await.is_err() {
                            return;
                        }
                    }
                    Err(status) => {
                        // The stream ends after an error either way, so a
                        // failed send (receiver already gone) is ignored.
                        let _ = resp_tx.send(Err(status)).await;
                        return;
                    }
                }
            }
        });

        Ok(Response::new(ServerReflectionInfoStream::new(resp_rx)))
    }
}
impl From<ReflectionServiceState> for ReflectionService {
    /// Builds a service that owns `state`, placing it behind an `Arc`.
    fn from(state: ReflectionServiceState) -> Self {
        let state = Arc::new(state);
        Self { state }
    }
}
/// A response stream.
///
/// Yields `Result<ServerReflectionResponse, Status>` items taken from the
/// receiving half of the channel passed to `new`.
pub struct ServerReflectionInfoStream {
/// Receiver-backed stream that the `Stream` implementation delegates to.
inner: tokio_stream::wrappers::ReceiverStream<Result<ServerReflectionResponse, Status>>,
}
impl ServerReflectionInfoStream {
    /// Wraps the receiving half of the response channel so it can be served
    /// as a stream.
    fn new(resp_rx: mpsc::Receiver<Result<ServerReflectionResponse, Status>>) -> Self {
        Self {
            inner: tokio_stream::wrappers::ReceiverStream::new(resp_rx),
        }
    }
}
impl Stream for ServerReflectionInfoStream {
type Item = Result<ServerReflectionResponse, Status>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
std::pin::Pin::new(&mut self.inner).poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
impl fmt::Debug for ServerReflectionInfoStream {
    /// Prints only the type name; the wrapped stream is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("ServerReflectionInfoStream")
    }
}