diff options
Diffstat (limited to 'vendor/prost-build/src/message_graph.rs')
| -rw-r--r-- | vendor/prost-build/src/message_graph.rs | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/vendor/prost-build/src/message_graph.rs b/vendor/prost-build/src/message_graph.rs index ac0ad152..7d43aece 100644 --- a/vendor/prost-build/src/message_graph.rs +++ b/vendor/prost-build/src/message_graph.rs @@ -4,7 +4,10 @@ use petgraph::algo::has_path_connecting; use petgraph::graph::NodeIndex; use petgraph::Graph; -use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto}; +use prost_types::{ + field_descriptor_proto::{Label, Type}, + DescriptorProto, FileDescriptorProto, +}; /// `MessageGraph` builds a graph of messages whose edges correspond to nesting. /// The goal is to recognize when message types are recursively nested, so @@ -12,15 +15,15 @@ use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto}; pub struct MessageGraph { index: HashMap<String, NodeIndex>, graph: Graph<String, ()>, + messages: HashMap<String, DescriptorProto>, } impl MessageGraph { - pub fn new<'a>( - files: impl Iterator<Item = &'a FileDescriptorProto>, - ) -> Result<MessageGraph, String> { + pub(crate) fn new<'a>(files: impl Iterator<Item = &'a FileDescriptorProto>) -> MessageGraph { let mut msg_graph = MessageGraph { index: HashMap::new(), graph: Graph::new(), + messages: HashMap::new(), }; for file in files { @@ -34,13 +37,14 @@ impl MessageGraph { } } - Ok(msg_graph) + msg_graph } fn get_or_insert_index(&mut self, msg_name: String) -> NodeIndex { let MessageGraph { ref mut index, ref mut graph, + .. } = *self; assert_eq!(b'.', msg_name.as_bytes()[0]); *index @@ -58,19 +62,23 @@ impl MessageGraph { let msg_index = self.get_or_insert_index(msg_name.clone()); for field in &msg.field { - if field.r#type() == field_descriptor_proto::Type::Message - && field.label() != field_descriptor_proto::Label::Repeated - { + if field.r#type() == Type::Message && field.label() != Label::Repeated { let field_index = self.get_or_insert_index(field.type_name.clone().unwrap()); self.graph.add_edge(msg_index, field_index, ()); } } + self.messages.insert(msg_name.clone(), msg.clone()); for msg in &msg.nested_type { self.add_message(&msg_name, msg); } } + /// Try get a message descriptor from current message graph + pub fn get_message(&self, message: &str) -> Option<&DescriptorProto> { + self.messages.get(message) + } + /// Returns true if message type `inner` is nested in message type `outer`. pub fn is_nested(&self, outer: &str, inner: &str) -> bool { let outer = match self.index.get(outer) { |
