diff --git a/src/generator/json_registry.rs b/src/generator/json_registry.rs index 52dd573..53adba3 100644 --- a/src/generator/json_registry.rs +++ b/src/generator/json_registry.rs @@ -1,5 +1,11 @@ +use crate::generator::utils::SourceContext; + use protobuf::{ - Message, UnknownValueRef, descriptor::DescriptorProto, descriptor::field_descriptor_proto::Type, + Message, UnknownValueRef, + descriptor::{ + DescriptorProto, FileDescriptorProto, + field_descriptor_proto::{Label, Type}, + }, }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -52,12 +58,15 @@ impl From for WireType { /// Generate json registry from given messages pub fn generate_json_registry( + fd: &FileDescriptorProto, messages: &[DescriptorProto], package: &str, parent_path: Option<&str>, out: &mut ProtoJsonRegistry, ) { - for message in messages { + let ctx = SourceContext::new(fd); + + for (i, message) in messages.iter().enumerate() { if message.options.map_entry() { continue; } @@ -77,7 +86,7 @@ pub fn generate_json_registry( // iterate nested types if !message.nested_type.is_empty() { - generate_json_registry(&message.nested_type, package, Some(&scoped_name), out); + generate_json_registry(fd, &message.nested_type, package, Some(&scoped_name), out); } // collect xor constants @@ -91,17 +100,43 @@ pub fn generate_json_registry( None }; + let is_repeated = field.label() == Label::LABEL_REPEATED; + let is_scalar = matches!( + field.type_(), + Type::TYPE_INT32 + | Type::TYPE_INT64 + | Type::TYPE_UINT32 + | Type::TYPE_UINT64 + | Type::TYPE_SINT32 + | Type::TYPE_SINT64 + | Type::TYPE_BOOL + | Type::TYPE_ENUM + | Type::TYPE_FIXED32 + | Type::TYPE_SFIXED32 + | Type::TYPE_FLOAT + | Type::TYPE_FIXED64 + | Type::TYPE_SFIXED64 + | Type::TYPE_DOUBLE + ); + + let is_packed = + (is_repeated && is_scalar && !field.options.has_packed()) || field.options.packed(); + out.messages.entry(full_name.clone()).or_default().insert( field.name().to_string(), MessageField { field_number: field.number() as u32, - wire_type: field.type_().into(), + wire_type: if is_packed { + WireType::LengthDelimited + } else { + field.type_().into() + }, xor_const: xor_const_field, }, ); } - // collect command ids + // collect command ids from message extension if let Some(command_id_field) = message .options .unknown_fields() @@ -109,6 +144,40 @@ pub fn generate_json_registry( && let UnknownValueRef::Varint(varint) = command_id_field { out.command_ids.insert(full_name, varint as u16); + continue; }; + + // collect command ids from comment + if let Some(source_loc) = ctx.get_message_loc(i) { + let Some(comment) = if source_loc.has_leading_comments() { + source_loc.leading_comments().trim() + } else if source_loc.has_trailing_comments() { + source_loc.trailing_comments().trim() + } else { + continue; + } + .split("\n") + .map(str::trim) + .last() else { + continue; + }; + + let lowercased_comment = comment.to_lowercase(); + + // CmdId or CmdID + if lowercased_comment.starts_with("cmdid: ") + && let Ok(cmd_id) = comment[7..].parse::() + { + out.command_ids.insert(full_name, cmd_id); + } + // MessageId + else if lowercased_comment.starts_with("messageid: ") + && let Some(rest) = comment.get(10..) + && let Some(num_part) = rest.split_whitespace().next() + && let Ok(message_id) = num_part.parse::() + { + out.command_ids.insert(full_name, message_id); + } + } } } diff --git a/src/generator/mod.rs b/src/generator/mod.rs index 0e4a109..7bfd371 100644 --- a/src/generator/mod.rs +++ b/src/generator/mod.rs @@ -10,6 +10,7 @@ use std::io; use crate::generator::json_registry::{ProtoJsonRegistry, generate_json_registry}; mod json_registry; +mod utils; pub fn generate(input: &[u8]) -> io::Result { let codegen_request = CodeGeneratorRequest::parse_from_bytes(input)?; @@ -24,7 +25,7 @@ pub fn generate(input: &[u8]) -> io::Result { } let mut out = ProtoJsonRegistry::default(); - generate_json_registry(&proto.message_type, proto.package(), None, &mut out); + generate_json_registry(&proto, &proto.message_type, proto.package(), None, &mut out); let mut json_file = File::new(); json_file.set_content(serde_json::to_string_pretty(&out).unwrap()); diff --git a/src/generator/utils.rs b/src/generator/utils.rs new file mode 100644 index 0000000..3fe0e54 --- /dev/null +++ b/src/generator/utils.rs @@ -0,0 +1,27 @@ +use protobuf::descriptor::{FileDescriptorProto, source_code_info::Location}; + +pub struct SourceContext<'a> { + location_map: std::collections::HashMap, &'a Location>, +} + +impl<'a> SourceContext<'a> { + pub fn new(file: &'a FileDescriptorProto) -> Self { + let mut map = std::collections::HashMap::new(); + for loc in &file.source_code_info.location { + map.insert(loc.path.clone(), loc); + } + Self { location_map: map } + } + + pub fn get_message_loc(&self, msg_index: usize) -> Option<&Location> { + self.location_map.get(&vec![4, msg_index as i32]).copied() + } + + #[expect(unused)] + pub fn get_nested_loc(&self, parent_path: &[i32], nested_index: usize) -> Option<&Location> { + let mut path = parent_path.to_vec(); + path.push(3); + path.push(nested_index as i32); + self.location_map.get(&path).copied() + } +}