From 02942f198e83c5c216c9596ccc9dd8b35c8479f1 Mon Sep 17 00:00:00 2001 From: Alex Rocha Date: Fri, 12 Dec 2025 21:00:24 -0800 Subject: [PATCH] Generate child node traversal in visitor functions --- rust/ruby-rbs/build.rs | 127 ++++++++++++++++++++++++++++++++++++++- rust/ruby-rbs/src/lib.rs | 109 +++++++++++++++++++++++++++++++++ 2 files changed, 235 insertions(+), 1 deletion(-) diff --git a/rust/ruby-rbs/build.rs b/rust/ruby-rbs/build.rs index 0c7e3c258..b2142fa86 100644 --- a/rust/ruby-rbs/build.rs +++ b/rust/ruby-rbs/build.rs @@ -185,16 +185,135 @@ fn write_visit_trait(file: &mut File, config: &Config) -> Result<(), Box `visit_type_name_node`). + let visitor_method_names: std::collections::HashMap = config + .nodes + .iter() + .map(|node| { + let c_type = convert_name(&node.name, CIdentifier::Type); + let c_type = c_type.strip_suffix("_t").unwrap_or(&c_type).to_string(); + let method = convert_name(node.variant_name(), CIdentifier::Method); + (c_type, method) + }) + .collect(); + + let is_visitable = |c_type: &str| -> bool { + matches!(c_type, "rbs_node" | "rbs_node_list" | "rbs_hash") + || visitor_method_names.contains_key(c_type) + }; + for node in &config.nodes { let node_variant_name = node.variant_name(); let method_name = convert_name(node_variant_name, CIdentifier::Method); - writeln!(file, "#[allow(unused_variables)]")?; // TODO: Remove this once all nodes that need visitor are implemented + let has_visitable_fields = node + .fields + .iter() + .flatten() + .any(|field| is_visitable(&field.c_type)); + + if !has_visitable_fields { + // If there's nothing to visit in this node, write the empty method with + // underscored parameters, and skip to the next iteration + writeln!( + file, + "pub fn visit_{method_name}_node(_visitor: &mut V, _node: &{node_variant_name}Node) {{}}" + )?; + + continue; + } + writeln!( file, "pub fn visit_{}_node(visitor: &mut V, node: &{}Node) {{", method_name, node_variant_name )?; + + if let Some(fields) = &node.fields { + for field in fields { + let field_method_name = if field.name == "type" { + "type_" + } else { + field.name.as_str() + }; + + match field.c_type.as_str() { + "rbs_node" => { + if field.optional { + writeln!( + file, + " if let Some(item) = node.{field_method_name}() {{" + )?; + writeln!(file, " visitor.visit(&item);")?; + writeln!(file, " }}")?; + } else { + writeln!(file, " visitor.visit(&node.{field_method_name}());")?; + } + } + + "rbs_node_list" => { + if field.optional { + writeln!( + file, + " if let Some(list) = node.{field_method_name}() {{" + )?; + writeln!(file, " for item in list.iter() {{")?; + writeln!(file, " visitor.visit(&item);")?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + } else { + writeln!(file, " for item in node.{field_method_name}().iter() {{")?; + writeln!(file, " visitor.visit(&item);")?; + writeln!(file, " }}")?; + } + } + + "rbs_hash" => { + if field.optional { + writeln!( + file, + " if let Some(hash) = node.{field_method_name}() {{" + )?; + writeln!(file, " for (key, value) in hash.iter() {{")?; + writeln!(file, " visitor.visit(&key);")?; + writeln!(file, " visitor.visit(&value);")?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + } else { + writeln!( + file, + " for (key, value) in node.{field_method_name}().iter() {{" + )?; + writeln!(file, " visitor.visit(&key);")?; + writeln!(file, " visitor.visit(&value);")?; + writeln!(file, " }}")?; + } + } + + _ => { + if let Some(visit_method_name) = visitor_method_names.get(&field.c_type) { + if field.optional { + writeln!( + file, + " if let Some(item) = node.{field_method_name}() {{" + )?; + writeln!( + file, + " visitor.visit_{visit_method_name}_node(&item);" + )?; + writeln!(file, " }}")?; + } else { + writeln!( + file, + " visitor.visit_{visit_method_name}_node(&node.{field_method_name}());" + )?; + } + } + } + } + } + } writeln!(file, "}}")?; writeln!(file)?; } @@ -226,6 +345,12 @@ fn generate(config: &Config) -> Result<(), Box> { writeln!(file, "}}\n")?; writeln!(file, "impl {} {{", node.rust_name)?; + writeln!(file, " /// Converts this node to a generic node.")?; + writeln!(file, " #[must_use]")?; + writeln!(file, " pub fn as_node(self) -> Node {{")?; + writeln!(file, " Node::{}(self)", node.variant_name())?; + writeln!(file, " }}")?; + if let Some(fields) = &node.fields { for field in fields { match field.c_type.as_str() { diff --git a/rust/ruby-rbs/src/lib.rs b/rust/ruby-rbs/src/lib.rs index cad95c118..45df7145f 100644 --- a/rust/ruby-rbs/src/lib.rs +++ b/rust/ruby-rbs/src/lib.rs @@ -326,4 +326,113 @@ mod tests { panic!("Expected TypeAlias with RecordType"); } } + + #[test] + fn visitor_test() { + struct Visitor { + visited: Vec, + } + + impl Visit for Visitor { + fn visit_bool_type_node(&mut self, node: &BoolTypeNode) { + self.visited.push("type:bool".to_string()); + + crate::visit_bool_type_node(self, node); + } + + fn visit_class_node(&mut self, node: &ClassNode) { + self.visited.push(format!( + "class:{}", + String::from_utf8(node.name().name().name().to_vec()).unwrap() + )); + + crate::visit_class_node(self, node); + } + + fn visit_class_instance_type_node(&mut self, node: &ClassInstanceTypeNode) { + self.visited.push(format!( + "type:{}", + String::from_utf8(node.name().name().name().to_vec()).unwrap() + )); + + crate::visit_class_instance_type_node(self, node); + } + + fn visit_class_super_node(&mut self, node: &ClassSuperNode) { + self.visited.push(format!( + "super:{}", + String::from_utf8(node.name().name().name().to_vec()).unwrap() + )); + + crate::visit_class_super_node(self, node); + } + + fn visit_function_type_node(&mut self, node: &FunctionTypeNode) { + let count = node.required_positionals().iter().count(); + self.visited + .push(format!("function:required_positionals:{count}")); + + crate::visit_function_type_node(self, node); + } + + fn visit_method_definition_node(&mut self, node: &MethodDefinitionNode) { + self.visited.push(format!( + "method:{}", + String::from_utf8(node.name().name().to_vec()).unwrap() + )); + + crate::visit_method_definition_node(self, node); + } + + fn visit_record_type_node(&mut self, node: &RecordTypeNode) { + self.visited.push("record".to_string()); + + crate::visit_record_type_node(self, node); + } + + fn visit_symbol_node(&mut self, node: &SymbolNode) { + self.visited.push(format!( + "symbol:{}", + String::from_utf8(node.name().to_vec()).unwrap() + )); + + crate::visit_symbol_node(self, node); + } + } + + let rbs_code = r#" + class Foo < Bar + def process: ({ name: String, age: Integer }, bool) -> void + end + "#; + + let signature = parse(rbs_code.as_bytes()).unwrap(); + + let mut visitor = Visitor { + visited: Vec::new(), + }; + + visitor.visit(&signature.as_node()); + + assert_eq!( + vec![ + "class:Foo", + "symbol:Foo", + "super:Bar", + "symbol:Bar", + "method:process", + "symbol:process", + "function:required_positionals:2", + "record", + "symbol:name", + "type:String", + "symbol:String", + "symbol:age", + "type:Integer", + "symbol:Integer", + "type:bool", + ], + visitor.visited + ); + } }