From fc4ccf35735e0232d6d542063de350cae4e528a6 Mon Sep 17 00:00:00 2001 From: Alex Rocha Date: Mon, 17 Nov 2025 18:39:04 -0800 Subject: [PATCH] Add Node and NodeList types to Rust RBS bindings --- rust/ruby-rbs/build.rs | 100 +++++++++++++++++++++++++++++++++------ rust/ruby-rbs/src/lib.rs | 36 ++++++++++++++ 2 files changed, 121 insertions(+), 15 deletions(-) diff --git a/rust/ruby-rbs/build.rs b/rust/ruby-rbs/build.rs index 2e470c7c6..3e2ab08da 100644 --- a/rust/ruby-rbs/build.rs +++ b/rust/ruby-rbs/build.rs @@ -35,6 +35,47 @@ fn main() -> Result<(), Box> { Ok(()) } +enum CIdentifier { + Type, // foo_bar_t + Constant, // FOO_BAR +} + +fn convert_name(name: &str, identifier: CIdentifier) -> String { + let type_name = name.replace("::", "_"); + let lowercase = matches!(identifier, CIdentifier::Type); + let mut out = String::new(); + let mut prev_is_lower = false; + + for ch in type_name.chars() { + if ch.is_ascii_uppercase() { + if prev_is_lower { + out.push('_'); + } + out.push(if lowercase { + ch.to_ascii_lowercase() + } else { + ch + }); + prev_is_lower = false; + } else if ch == '_' { + out.push(ch); + prev_is_lower = false; + } else { + out.push(if lowercase { + ch + } else { + ch.to_ascii_uppercase() + }); + prev_is_lower = ch.is_ascii_lowercase() || ch.is_ascii_digit(); + } + } + + if lowercase { + out.push_str("_t"); + } + out +} + fn generate(config: &Config) -> Result<(), Box> { let out_dir = env::var("OUT_DIR")?; let dest_path = Path::new(&out_dir).join("bindings.rs"); @@ -50,18 +91,11 @@ fn generate(config: &Config) -> Result<(), Box> { writeln!(file, "#[allow(dead_code)]")?; // TODO: Remove this once all nodes that need parser are implemented writeln!(file, "pub struct {} {{", node.rust_name)?; writeln!(file, " parser: *mut rbs_parser_t,")?; - if let Some(fields) = &node.fields { - for field in fields { - match field.c_type.as_str() { - "rbs_string" => writeln!(file, " {}: *const rbs_string_t,", field.name)?, - "bool" => writeln!(file, " {}: bool,", field.name)?, - "rbs_ast_symbol" => { - writeln!(file, " {}: *const rbs_ast_symbol_t,", field.name)? - } - _ => eprintln!("Unknown field type: {}", field.c_type), - } - } - } + writeln!( + file, + " pointer: *mut {},", + convert_name(&node.name, CIdentifier::Type) + )?; writeln!(file, "}}\n")?; writeln!(file, "impl {} {{", node.rust_name)?; @@ -70,19 +104,23 @@ fn generate(config: &Config) -> Result<(), Box> { match field.c_type.as_str() { "rbs_string" => { writeln!(file, " pub fn {}(&self) -> RBSString {{", field.name)?; - writeln!(file, " RBSString::new(self.{})", field.name)?; + writeln!( + file, + " RBSString::new(unsafe {{ &(*self.pointer).{} }})", + field.name + )?; writeln!(file, " }}")?; } "bool" => { writeln!(file, " pub fn {}(&self) -> bool {{", field.name)?; - writeln!(file, " self.{}", field.name)?; + writeln!(file, " unsafe {{ (*self.pointer).{} }}", field.name)?; writeln!(file, " }}")?; } "rbs_ast_symbol" => { writeln!(file, " pub fn {}(&self) -> RBSSymbol {{", field.name)?; writeln!( file, - " RBSSymbol::new(self.{}, self.parser)", + " RBSSymbol::new(unsafe {{ (*self.pointer).{} }}, self.parser)", field.name )?; writeln!(file, " }}")?; @@ -106,5 +144,37 @@ fn generate(config: &Config) -> Result<(), Box> { } writeln!(file, "}}")?; + writeln!(file, "impl Node {{")?; + writeln!(file, " #[allow(clippy::missing_safety_doc)]")?; + writeln!( + file, + " pub unsafe fn new(parser: *mut rbs_parser_t, node: *mut rbs_node_t) -> Self {{" + )?; + writeln!(file, " match unsafe {{ (*node).type_ }} {{")?; + for node in &config.nodes { + let variant_name = node + .rust_name + .strip_suffix("Node") + .unwrap_or(&node.rust_name); + + let enum_name = convert_name(&node.name, CIdentifier::Constant); + + writeln!( + file, + " rbs_node_type::{} => Self::{}({} {{ parser, pointer: node.cast::<{}>() }}),", + enum_name, + variant_name, + node.rust_name, + convert_name(&node.name, CIdentifier::Type) + )?; + } + writeln!( + file, + " _ => panic!(\"Unknown node type: {{}}\", unsafe {{ (*node).type_ }})" + )?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + writeln!(file, "}}")?; + Ok(()) } diff --git a/rust/ruby-rbs/src/lib.rs b/rust/ruby-rbs/src/lib.rs index d90496129..a81982f42 100644 --- a/rust/ruby-rbs/src/lib.rs +++ b/rust/ruby-rbs/src/lib.rs @@ -40,6 +40,42 @@ pub fn parse(rbs_code: &[u8]) -> Result<*mut rbs_signature_t, String> { } } +pub struct NodeListIter { + parser: *mut rbs_parser_t, + current: *mut rbs_node_list_node_t, +} + +impl Iterator for NodeListIter { + type Item = Node; + + fn next(&mut self) -> Option { + if self.current.is_null() { + None + } else { + let pointer_data = unsafe { *self.current }; + let node = unsafe { Node::new(self.parser, pointer_data.node) }; + self.current = pointer_data.next; + Some(node) + } + } +} + +pub struct NodeList { + parser: *mut rbs_parser_t, + pointer: *mut rbs_node_list_t, +} + +impl NodeList { + /// Returns an iterator over the nodes. + #[must_use] + pub fn iter(&self) -> NodeListIter { + NodeListIter { + parser: self.parser, + current: unsafe { (*self.pointer).head }, + } + } +} + pub struct RBSString { pointer: *const rbs_string_t, }