feat: parse commmand ids from comments

This commit is contained in:
amizing25
2026-01-11 10:08:21 +07:00
parent a6cb0a3502
commit 0f69c0508e
3 changed files with 103 additions and 6 deletions

View File

@@ -1,5 +1,11 @@
use crate::generator::utils::SourceContext;
use protobuf::{
Message, UnknownValueRef, descriptor::DescriptorProto, descriptor::field_descriptor_proto::Type,
Message, UnknownValueRef,
descriptor::{
DescriptorProto, FileDescriptorProto,
field_descriptor_proto::{Label, Type},
},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
@@ -52,12 +58,15 @@ impl From<Type> for WireType {
/// Generate json registry from given messages
pub fn generate_json_registry(
fd: &FileDescriptorProto,
messages: &[DescriptorProto],
package: &str,
parent_path: Option<&str>,
out: &mut ProtoJsonRegistry,
) {
for message in messages {
let ctx = SourceContext::new(fd);
for (i, message) in messages.iter().enumerate() {
if message.options.map_entry() {
continue;
}
@@ -77,7 +86,7 @@ pub fn generate_json_registry(
// iterate nested types
if !message.nested_type.is_empty() {
generate_json_registry(&message.nested_type, package, Some(&scoped_name), out);
generate_json_registry(fd, &message.nested_type, package, Some(&scoped_name), out);
}
// collect xor constants
@@ -91,17 +100,43 @@ pub fn generate_json_registry(
None
};
let is_repeated = field.label() == Label::LABEL_REPEATED;
let is_scalar = matches!(
field.type_(),
Type::TYPE_INT32
| Type::TYPE_INT64
| Type::TYPE_UINT32
| Type::TYPE_UINT64
| Type::TYPE_SINT32
| Type::TYPE_SINT64
| Type::TYPE_BOOL
| Type::TYPE_ENUM
| Type::TYPE_FIXED32
| Type::TYPE_SFIXED32
| Type::TYPE_FLOAT
| Type::TYPE_FIXED64
| Type::TYPE_SFIXED64
| Type::TYPE_DOUBLE
);
let is_packed =
(is_repeated && is_scalar && !field.options.has_packed()) || field.options.packed();
out.messages.entry(full_name.clone()).or_default().insert(
field.name().to_string(),
MessageField {
field_number: field.number() as u32,
wire_type: field.type_().into(),
wire_type: if is_packed {
WireType::LengthDelimited
} else {
field.type_().into()
},
xor_const: xor_const_field,
},
);
}
// collect command ids
// collect command ids from message extension
if let Some(command_id_field) = message
.options
.unknown_fields()
@@ -109,6 +144,40 @@ pub fn generate_json_registry(
&& let UnknownValueRef::Varint(varint) = command_id_field
{
out.command_ids.insert(full_name, varint as u16);
continue;
};
// collect command ids from comment
if let Some(source_loc) = ctx.get_message_loc(i) {
let Some(comment) = if source_loc.has_leading_comments() {
source_loc.leading_comments().trim()
} else if source_loc.has_trailing_comments() {
source_loc.trailing_comments().trim()
} else {
continue;
}
.split("\n")
.map(str::trim)
.last() else {
continue;
};
let lowercased_comment = comment.to_lowercase();
// CmdId or CmdID
if lowercased_comment.starts_with("cmdid: ")
&& let Ok(cmd_id) = comment[7..].parse::<u16>()
{
out.command_ids.insert(full_name, cmd_id);
}
// MessageId
else if lowercased_comment.starts_with("messageid: ")
&& let Some(rest) = comment.get(10..)
&& let Some(num_part) = rest.split_whitespace().next()
&& let Ok(message_id) = num_part.parse::<u16>()
{
out.command_ids.insert(full_name, message_id);
}
}
}
}

View File

@@ -10,6 +10,7 @@ use std::io;
use crate::generator::json_registry::{ProtoJsonRegistry, generate_json_registry};
mod json_registry;
mod utils;
pub fn generate(input: &[u8]) -> io::Result<CodeGeneratorResponse> {
let codegen_request = CodeGeneratorRequest::parse_from_bytes(input)?;
@@ -24,7 +25,7 @@ pub fn generate(input: &[u8]) -> io::Result<CodeGeneratorResponse> {
}
let mut out = ProtoJsonRegistry::default();
generate_json_registry(&proto.message_type, proto.package(), None, &mut out);
generate_json_registry(&proto, &proto.message_type, proto.package(), None, &mut out);
let mut json_file = File::new();
json_file.set_content(serde_json::to_string_pretty(&out).unwrap());

27
src/generator/utils.rs Normal file
View File

@@ -0,0 +1,27 @@
use protobuf::descriptor::{FileDescriptorProto, source_code_info::Location};
pub struct SourceContext<'a> {
location_map: std::collections::HashMap<Vec<i32>, &'a Location>,
}
impl<'a> SourceContext<'a> {
pub fn new(file: &'a FileDescriptorProto) -> Self {
let mut map = std::collections::HashMap::new();
for loc in &file.source_code_info.location {
map.insert(loc.path.clone(), loc);
}
Self { location_map: map }
}
pub fn get_message_loc(&self, msg_index: usize) -> Option<&Location> {
self.location_map.get(&vec![4, msg_index as i32]).copied()
}
#[expect(unused)]
pub fn get_nested_loc(&self, parent_path: &[i32], nested_index: usize) -> Option<&Location> {
let mut path = parent_path.to_vec();
path.push(3);
path.push(nested_index as i32);
self.location_map.get(&path).copied()
}
}