summaryrefslogtreecommitdiff
path: root/vendor/prost-build/src/message_graph.rs
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/prost-build/src/message_graph.rs')
-rw-r--r--vendor/prost-build/src/message_graph.rs24
1 files changed, 16 insertions, 8 deletions
diff --git a/vendor/prost-build/src/message_graph.rs b/vendor/prost-build/src/message_graph.rs
index ac0ad152..7d43aece 100644
--- a/vendor/prost-build/src/message_graph.rs
+++ b/vendor/prost-build/src/message_graph.rs
@@ -4,7 +4,10 @@ use petgraph::algo::has_path_connecting;
use petgraph::graph::NodeIndex;
use petgraph::Graph;
-use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto};
+use prost_types::{
+ field_descriptor_proto::{Label, Type},
+ DescriptorProto, FileDescriptorProto,
+};
/// `MessageGraph` builds a graph of messages whose edges correspond to nesting.
/// The goal is to recognize when message types are recursively nested, so
@@ -12,15 +15,15 @@ use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto};
pub struct MessageGraph {
index: HashMap<String, NodeIndex>,
graph: Graph<String, ()>,
+ messages: HashMap<String, DescriptorProto>,
}
impl MessageGraph {
- pub fn new<'a>(
- files: impl Iterator<Item = &'a FileDescriptorProto>,
- ) -> Result<MessageGraph, String> {
+ pub(crate) fn new<'a>(files: impl Iterator<Item = &'a FileDescriptorProto>) -> MessageGraph {
let mut msg_graph = MessageGraph {
index: HashMap::new(),
graph: Graph::new(),
+ messages: HashMap::new(),
};
for file in files {
@@ -34,13 +37,14 @@ impl MessageGraph {
}
}
- Ok(msg_graph)
+ msg_graph
}
fn get_or_insert_index(&mut self, msg_name: String) -> NodeIndex {
let MessageGraph {
ref mut index,
ref mut graph,
+ ..
} = *self;
assert_eq!(b'.', msg_name.as_bytes()[0]);
*index
@@ -58,19 +62,23 @@ impl MessageGraph {
let msg_index = self.get_or_insert_index(msg_name.clone());
for field in &msg.field {
- if field.r#type() == field_descriptor_proto::Type::Message
- && field.label() != field_descriptor_proto::Label::Repeated
- {
+ if field.r#type() == Type::Message && field.label() != Label::Repeated {
let field_index = self.get_or_insert_index(field.type_name.clone().unwrap());
self.graph.add_edge(msg_index, field_index, ());
}
}
+ self.messages.insert(msg_name.clone(), msg.clone());
for msg in &msg.nested_type {
self.add_message(&msg_name, msg);
}
}
+ /// Try get a message descriptor from current message graph
+ pub fn get_message(&self, message: &str) -> Option<&DescriptorProto> {
+ self.messages.get(message)
+ }
+
/// Returns true if message type `inner` is nested in message type `outer`.
pub fn is_nested(&self, outer: &str, inner: &str) -> bool {
let outer = match self.index.get(outer) {