diff --git a/crates/core/src/types.rs b/crates/core/src/types.rs
index 35d6f5f6c..3e242a016 100644
--- a/crates/core/src/types.rs
+++ b/crates/core/src/types.rs
@@ -5,6 +5,7 @@ use wit_parser::*;
 #[derive(Default)]
 pub struct Types {
     type_info: HashMap<TypeId, TypeInfo>,
+    equal_types: UnionFind,
 }
 
 #[derive(Default, Clone, Copy, Debug)]
@@ -91,6 +92,29 @@ impl Types {
             }
         }
     }
+    /// Partitions the types of `resolve` into classes of structurally equal types and
+    /// records them in `self.equal_types`.
+    ///
+    /// Each type is compared with the types that precede it in `resolve.types`, skipping
+    /// ones already in its class; at the first structural match the two classes are
+    /// merged and the scan for that type stops.
+    pub fn collect_equal_types(&mut self, resolve: &Resolve) {
+        let type_ids: Vec<_> = resolve.types.iter().map(|(id, _)| id).collect();
+        for (i, &ty) in type_ids.iter().enumerate() {
+            // TODO: we could define a hash function for TypeDefKind to prevent the inner loop.
+            for &earlier in &type_ids[..i] {
+                if self.equal_types.find(ty) == self.equal_types.find(earlier) {
+                    continue;
+                }
+                // The correctness of is_structurally_equal relies on the fact that
+                // resolve.types.iter() is in topological order.
+                if self.is_structurally_equal(resolve, ty, earlier) {
+                    self.equal_types.union(ty, earlier);
+                    break;
+                }
+            }
+        }
+    }
 
     fn type_info_func(&mut self, resolve: &Resolve, func: &Function, import: bool) {
         let mut live = LiveTypes::default();
@@ -228,4 +252,150 @@ impl Types {
             None => TypeInfo::default(),
         }
     }
+    /// Returns whether the definitions of `a` and `b` are structurally equal: the same
+    /// kind, the same field/case/flag names in the same order, and member types that
+    /// already belong to the same equivalence class.
+    ///
+    /// Resource, handle, future and stream definitions never compare equal.
+    fn is_structurally_equal(&mut self, resolve: &Resolve, a: TypeId, b: TypeId) -> bool {
+        let a_def = &resolve.types[a].kind;
+        let b_def = &resolve.types[b].kind;
+        if self.is_resource_like_type(a_def) || self.is_resource_like_type(b_def) {
+            return false;
+        }
+        match (a_def, b_def) {
+            (TypeDefKind::Type(ta), TypeDefKind::Type(tb)) => {
+                // This function is called in topological order, so the equivalence
+                // classes of ta and tb have already been computed. We can use the representative
+                // TypeId to check equality, instead of recursing down.
+                self.types_equal(resolve, ta, tb)
+            }
+            (TypeDefKind::Record(ra), TypeDefKind::Record(rb)) => {
+                ra.fields.len() == rb.fields.len()
+                // Fields are ordered in WIT, so record {a: T, b: U} is different from {b: U, a: T}
+                && ra.fields.iter().zip(rb.fields.iter()).all(|(fa, fb)| {
+                    fa.name == fb.name && self.types_equal(resolve, &fa.ty, &fb.ty)
+                })
+            }
+            (TypeDefKind::Variant(va), TypeDefKind::Variant(vb)) => {
+                va.cases.len() == vb.cases.len()
+                    && va.cases.iter().zip(vb.cases.iter()).all(|(ca, cb)| {
+                        ca.name == cb.name && self.optional_types_equal(resolve, &ca.ty, &cb.ty)
+                    })
+            }
+            (TypeDefKind::Enum(ea), TypeDefKind::Enum(eb)) => {
+                ea.cases.len() == eb.cases.len()
+                    && ea
+                        .cases
+                        .iter()
+                        .zip(eb.cases.iter())
+                        .all(|(ca, cb)| ca.name == cb.name)
+            }
+            (TypeDefKind::Flags(fa), TypeDefKind::Flags(fb)) => {
+                fa.flags.len() == fb.flags.len()
+                    && fa
+                        .flags
+                        .iter()
+                        .zip(fb.flags.iter())
+                        .all(|(fa, fb)| fa.name == fb.name)
+            }
+            (TypeDefKind::Tuple(ta), TypeDefKind::Tuple(tb)) => {
+                ta.types.len() == tb.types.len()
+                    && ta
+                        .types
+                        .iter()
+                        .zip(tb.types.iter())
+                        .all(|(a, b)| self.types_equal(resolve, a, b))
+            }
+            (TypeDefKind::List(la), TypeDefKind::List(lb)) => self.types_equal(resolve, la, lb),
+            (TypeDefKind::FixedSizeList(ta, sa), TypeDefKind::FixedSizeList(tb, sb)) => {
+                sa == sb && self.types_equal(resolve, ta, tb)
+            }
+            (TypeDefKind::Option(oa), TypeDefKind::Option(ob)) => self.types_equal(resolve, oa, ob),
+            (TypeDefKind::Result(ra), TypeDefKind::Result(rb)) => {
+                self.optional_types_equal(resolve, &ra.ok, &rb.ok)
+                    && self.optional_types_equal(resolve, &ra.err, &rb.err)
+            }
+            _ => false,
+        }
+    }
+    /// Compares two type references. Two `Type::Id`s are equal when neither is
+    /// resource-like and both have the same representative in `equal_types`;
+    /// `error-context` is never equal; anything else is compared with `==`.
+    fn types_equal(&mut self, resolve: &Resolve, a: &Type, b: &Type) -> bool {
+        match (a, b) {
+            (Type::Id(a), Type::Id(b)) => {
+                let a_def = &resolve.types[*a].kind;
+                let b_def = &resolve.types[*b].kind;
+                if self.is_resource_like_type(a_def) || self.is_resource_like_type(b_def) {
+                    return false;
+                }
+                self.equal_types.find(*a) == self.equal_types.find(*b)
+            }
+            // `error-context` is grouped with the resource-like types: never equal.
+            // Declining to merge is always safe, whereas `todo!()` here aborted generation.
+            (Type::ErrorContext, Type::ErrorContext) => false,
+            _ => a == b,
+        }
+    }
+    /// Like `types_equal`, for optional payloads: both absent, or both present and equal.
+    fn optional_types_equal(
+        &mut self,
+        resolve: &Resolve,
+        a: &Option<Type>,
+        b: &Option<Type>,
+    ) -> bool {
+        match (a, b) {
+            (Some(a), Some(b)) => self.types_equal(resolve, a, b),
+            (None, None) => true,
+            _ => false,
+        }
+    }
+    /// Returns true for resources, handles, futures and streams, which are never merged.
+    fn is_resource_like_type(&self, ty: &TypeDefKind) -> bool {
+        match ty {
+            TypeDefKind::Resource | TypeDefKind::Handle(_) => true,
+            TypeDefKind::Future(_) | TypeDefKind::Stream(_) => true,
+            _ => false,
+        }
+    }
+    /// Returns the representative of `id`'s equivalence class; this is `id` itself
+    /// when `id` has not been merged with any other type.
+    pub fn get_representative_type(&mut self, id: TypeId) -> TypeId {
+        self.equal_types.find(id)
+    }
+}
+
+/// Union-find (disjoint-set) over `TypeId`s. An id with no entry in `parent` is the
+/// root of its own singleton class.
+#[derive(Default)]
+pub struct UnionFind {
+    parent: HashMap<TypeId, TypeId>,
+}
+impl UnionFind {
+    /// Returns the root of `id`'s class, re-pointing `id` at that root on the way back.
+    fn find(&mut self, id: TypeId) -> TypeId {
+        // Path compression
+        let parent = self.parent.get(&id).copied().unwrap_or(id);
+        if parent != id {
+            let root = self.find(parent);
+            self.parent.insert(id, root);
+            root
+        } else {
+            id
+        }
+    }
+    /// Merges the classes of `a` and `b`.
+    fn union(&mut self, a: TypeId, b: TypeId) {
+        let ra = self.find(a);
+        let rb = self.find(b);
+        if ra != rb {
+            // Use smaller id as root for determinism
+            if ra < rb {
+                self.parent.insert(rb, ra);
+            } else {
+                self.parent.insert(ra, rb);
+            }
+        }
+    }
 }
diff --git a/crates/rust/src/interface.rs b/crates/rust/src/interface.rs
index c9a55eb14..b5184bc75 100644
--- a/crates/rust/src/interface.rs
+++ b/crates/rust/src/interface.rs
@@ -2326,6 +2326,48 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8)
         }
     }
 
+    // NOTE(review): `eqaul` is a typo for `equal`; a rename must also update the
+    // callers in `crates/rust/src/lib.rs`.
+    /// Emits `pub type` aliases making `id` an alias of the structurally equal type
+    /// `eq_ty`, one per mode returned by `modes_of(id)`.
+    ///
+    /// With `from_import` set, the alias target is prefixed with `path_to_root()`
+    /// followed by that import's module path; otherwise `print_typedef_alias` is
+    /// used.
+    ///
+    /// Panics if `merge_structurally_equal_types` is not enabled.
+    pub fn type_alias_to_eqaul_type(
+        &mut self,
+        id: TypeId,
+        eq_ty: TypeId,
+        from_import: Option<&WorldKey>,
+    ) {
+        assert!(self.r#gen.opts.merge_structurally_equal_types);
+        if let Some(name) = from_import {
+            let docs = Docs {
+                contents: Some("wit-bindgen: alias to import equal type".to_string()),
+            };
+            let mut path = self.path_to_root();
+            let import_path = crate::compute_module_path(name, self.resolve, false).join("::");
+            path.push_str(&import_path);
+            path.push_str("::");
+            for (name, mode) in self.modes_of(id) {
+                self.rustdoc(&docs);
+                self.push_str(&format!("pub type {name}"));
+                self.print_generics(mode.lifetime);
+                self.push_str(" = ");
+                self.push_str(&path);
+                self.print_tyid(eq_ty, mode);
+                self.push_str(";\n");
+            }
+        } else {
+            let docs = Docs {
+                contents: Some("wit-bindgen: alias to equal type".to_string()),
+            };
+            self.print_typedef_alias(id, &Type::Id(eq_ty), &docs);
+        }
+    }
+
     fn print_typedef_alias(&mut self, id: TypeId, ty: &Type, docs: &Docs) {
         for (name, mode) in self.modes_of(id) {
             self.rustdoc(docs);
diff --git a/crates/rust/src/lib.rs b/crates/rust/src/lib.rs
index dece973ba..b48bbab34 100644
--- a/crates/rust/src/lib.rs
+++ b/crates/rust/src/lib.rs
@@ -274,6 +274,15 @@ pub struct Opts {
     #[cfg_attr(feature = "clap", clap(flatten))]
     #[cfg_attr(feature = "serde", serde(flatten))]
     pub async_: AsyncFilterSet,
+
+    /// Find all structurally equal types and only generate one type definition for
+    /// each equivalence class. Other types in the same class will be type aliases to the
+    /// generated type. This avoids cloning when converting between types that are
+    /// structurally equal, which is useful when importing and exporting the same interface.
+    ///
+    /// Types containing resource, future, or stream are never considered equal.
+    #[cfg_attr(feature = "clap", arg(long))]
+    pub merge_structurally_equal_types: bool,
 }
 
 impl Opts {
@@ -973,6 +982,42 @@ macro_rules! __export_{world_name}_impl {{
             .async_
             .is_async(resolve, interface, func, is_import)
     }
+
+    // Returns the structurally equal type id if one exists. If the equal type comes from the
+    // import of the same interface, also returns the interface key, so that we can generate
+    // a type alias to the import type.
+    fn get_equal_type_alias<'a>(
+        &mut self,
+        resolve: &Resolve,
+        iface_key: Option<&'a WorldKey>,
+        ty_id: TypeId,
+    ) -> Option<(TypeId, Option<&'a WorldKey>)> {
+        if !self.opts.merge_structurally_equal_types {
+            return None;
+        }
+        let ty = &resolve.types[ty_id].kind;
+        if matches!(ty, TypeDefKind::Type(_)) {
+            // preserve all primitive type and type alias definitions
+            return None;
+        }
+        let root = self.types.get_representative_type(ty_id);
+        if root != ty_id {
+            Some((root, None))
+        } else {
+            let TypeOwner::Interface(iface_id) = resolve.types[ty_id].owner else {
+                return None; // world-level types have no interface import to alias to
+            };
+            // When we allow importing/exporting the same interface multiple times, we need to update this code
+            if !self.types.get(ty_id).has_resource
+                && iface_key.is_some()
+                && let Some(true) = self.interface_last_seen_as_import.get(&iface_id)
+            {
+                Some((root, iface_key))
+            } else {
+                None
+            }
+        }
+    }
 }
 
 impl WorldGenerator for RustWasm {
@@ -1054,6 +1099,9 @@ impl WorldGenerator for RustWasm {
                 "// * disable-run-ctors-once-workaround"
             );
         }
+        if self.opts.merge_structurally_equal_types {
+            uwriteln!(self.src_preamble, "// * merge_structurally_equal_types");
+        }
         if let Some(s) = &self.opts.export_macro_name {
             uwriteln!(self.src_preamble, "// * export-macro-name: {s}");
         }
@@ -1073,6 +1121,9 @@ impl WorldGenerator for RustWasm {
             uwriteln!(self.src_preamble, "// * async: {opt}");
         }
         self.types.analyze(resolve);
+        if self.opts.merge_structurally_equal_types {
+            self.types.collect_equal_types(resolve);
+        }
         self.world = Some(world);
 
         let world = &resolve.worlds[world];
@@ -1105,13 +1156,14 @@ impl WorldGenerator for RustWasm {
         let mut to_define = Vec::new();
         for (name, ty_id) in resolve.interfaces[id].types.iter() {
             let full_name = full_wit_type_name(resolve, *ty_id);
+            let eq_alias = self.get_equal_type_alias(resolve, None, *ty_id);
             if let Some(type_gen) = self.with.get(&full_name) {
                 // skip type definition generation for remapped types
                 if type_gen.generated() {
-                    to_define.push((name, ty_id));
+                    to_define.push((name, ty_id, eq_alias));
                 }
             } else {
-                to_define.push((name, ty_id));
+                to_define.push((name, ty_id, eq_alias));
             }
             self.generated_types.insert(full_name);
         }
@@ -1129,8 +1181,12 @@ impl WorldGenerator for RustWasm {
             return Ok(());
         }
 
-        for (name, ty_id) in to_define {
-            r#gen.define_type(&name, *ty_id);
+        for (name, ty_id, eq_alias) in to_define {
+            if let Some((alias, _)) = eq_alias {
+                r#gen.type_alias_to_eqaul_type(*ty_id, alias, None);
+            } else {
+                r#gen.define_type(&name, *ty_id);
+            }
         }
 
         r#gen.generate_imports(resolve.interfaces[id].functions.values(), Some(name));
@@ -1167,9 +1223,10 @@ impl WorldGenerator for RustWasm {
         _files: &mut Files,
     ) -> Result<()> {
         let mut to_define = Vec::new();
-        for (name, ty_id) in resolve.interfaces[id].types.iter() {
+        for (ty_name, ty_id) in resolve.interfaces[id].types.iter() {
             let full_name = full_wit_type_name(resolve, *ty_id);
-            to_define.push((name, ty_id));
+            let eq_alias = self.get_equal_type_alias(resolve, Some(name), *ty_id);
+            to_define.push((ty_name, ty_id, eq_alias));
             self.generated_types.insert(full_name);
         }
 
@@ -1186,8 +1243,12 @@ impl WorldGenerator for RustWasm {
             return Ok(());
         }
 
-        for (name, ty_id) in to_define {
-            r#gen.define_type(&name, *ty_id);
+        for (ty_name, ty_id, eq_alias) in to_define {
+            if let Some((alias, from_import)) = eq_alias {
+                r#gen.type_alias_to_eqaul_type(*ty_id, alias, from_import);
+            } else {
+                r#gen.define_type(&ty_name, *ty_id);
+            }
         }
 
         let macro_name =
@@ -1247,19 +1308,24 @@ impl WorldGenerator for RustWasm {
         let mut to_define = Vec::new();
         for (name, ty_id) in types {
             let full_name = full_wit_type_name(resolve, *ty_id);
+            let eq_alias = self.get_equal_type_alias(resolve, None, *ty_id);
             if let Some(type_gen) = self.with.get(&full_name) {
                 // skip type definition generation for remapped types
                 if type_gen.generated() {
-                    to_define.push((name, ty_id));
+                    to_define.push((name, ty_id, eq_alias));
                 }
             } else {
-                to_define.push((name, ty_id));
+                to_define.push((name, ty_id, eq_alias));
             }
             self.generated_types.insert(full_name);
         }
         let mut r#gen = self.interface(Identifier::World(world), "$root", resolve, true);
-        for (name, ty) in to_define {
-            r#gen.define_type(name, *ty);
+        for (name, ty, eq_alias) in to_define {
+            if let Some((alias, _)) = eq_alias {
+                r#gen.type_alias_to_eqaul_type(*ty, alias, None);
+            } else {
+                r#gen.define_type(name, *ty);
+            }
         }
         let src = r#gen.finish();
         self.src.push_str(&src);
@@ -1389,7 +1455,11 @@ impl WorldGenerator for RustWasm {
     }
 }
 
-fn compute_module_path(name: &WorldKey, resolve: &Resolve, is_export: bool) -> Vec<String> {
+pub(crate) fn compute_module_path(
+    name: &WorldKey,
+    resolve: &Resolve,
+    is_export: bool,
+) -> Vec<String> {
     let mut path = Vec::new();
     if is_export {
         path.push("exports".to_string());
diff --git a/tests/runtime/rust/equal-types/compose.wac b/tests/runtime/rust/equal-types/compose.wac
new file mode 100644
index 000000000..585bc6302
--- /dev/null
+++ b/tests/runtime/rust/equal-types/compose.wac
@@ -0,0 +1,6 @@
+package example:composition;
+
+let host = new test:host { ... };
+let proxy = new test:proxy { ...host, ... };
+let runner = new test:runner { ...proxy, ... };
+export runner...;
diff --git a/tests/runtime/rust/equal-types/host.rs b/tests/runtime/rust/equal-types/host.rs
new file mode 100644
index 000000000..bdd176e1b
--- /dev/null
+++ b/tests/runtime/rust/equal-types/host.rs
@@ -0,0 +1,34 @@
+//@ args = '--merge-structurally-equal-types'
+
+include!(env!("BINDINGS"));
+
+struct Component;
+
+export!(Component);
+
+use crate::exports::test::equal_types::blag::{Guest, Kind1, Kind3, Kind4, TStream, Tree, GuestInputStream, InputStream};
+
+impl GuestInputStream for u32 {
+    fn read(&self, _len: u64) -> Vec { Vec::new() }
+}
+
+impl Guest for Component {
+    type InputStream = u32;
+    fn f(x: Kind1) -> Kind1 { x }
+    fn g(x: Kind3) -> Kind4 { Kind4 { a: x.a } }
+    fn h(x: TStream) -> Tree { x.tree }
+}
+
+use crate::exports::test::equal_types::blah::{Guest as HGuest, Kind5, Kind6, Kind7, CustomResult};
+
+impl HGuest for Component {
+    fn f(x: Kind6) -> Kind5 {
+        match x {
+            Kind6::A => Kind1::A,
+            Kind6::B(x) => Kind5::B(x),
+            Kind6::C => Kind1::C,
+        }
+    }
+    fn g(x: Kind7)-> Kind4 { Kind4 { a: InputStream::new(*x.a.get::<u32>()) } }
+    fn h(x: TStream) -> CustomResult { CustomResult::Ok(x.tree) }
+}
diff --git a/tests/runtime/rust/equal-types/proxy.rs b/tests/runtime/rust/equal-types/proxy.rs
new file mode 100644
index 000000000..9a205f14c
--- /dev/null
+++ b/tests/runtime/rust/equal-types/proxy.rs
@@ -0,0 +1,36 @@
+//@ args = '--merge-structurally-equal-types'
+
+include!(env!("BINDINGS"));
+
+struct Component;
+
+export!(Component);
+
+use crate::test::equal_types::blag;
+use crate::exports::test::equal_types::blag::{Guest, Kind1, Kind3, Kind4, TStream, Tree, GuestInputStream};
+
+impl GuestInputStream for u32 {
+    fn read(&self, _len: u64) -> Vec { Vec::new() }
+}
+
+impl Guest for Component {
+    type InputStream = u32;
+    fn f(x: Kind1) -> Kind1 { blag::f(x) }
+    fn g(_x: Kind3) -> Kind4 { todo!() }
+    fn h(x: TStream) -> Tree {
+        let x = blag::TStream { tree: x.tree, stream: None };
+        blag::h(&x)
+    }
+}
+
+use crate::test::equal_types::blah;
+use crate::exports::test::equal_types::blah::{Guest as HGuest, Kind5, Kind6, Kind7, CustomResult};
+
+impl HGuest for Component {
+    fn f(x: Kind6) -> Kind5 { blah::f(x) }
+    fn g(_x: Kind7)-> Kind4 { todo!() }
+    fn h(x: TStream) -> CustomResult {
+        let x = blah::TStream { tree: x.tree, stream: None };
+        blah::h(&x)
+    }
+}
diff --git a/tests/runtime/rust/equal-types/runner.rs b/tests/runtime/rust/equal-types/runner.rs
new file mode 100644
index 000000000..b56643f4c
--- /dev/null
+++ b/tests/runtime/rust/equal-types/runner.rs
@@ -0,0 +1,35 @@
+//@ args = ['--merge-structurally-equal-types', '-dPartialEq', '--additional-derive-ignore=kind7', '--additional-derive-ignore=kind3', '--additional-derive-ignore=kind4', '--additional-derive-ignore=t-stream']
+
+include!(env!("BINDINGS"));
+
+use crate::test::equal_types::{blag, blah};
+
+struct Component;
+
+export!(Component);
+
+impl Guest for Component {
+    fn run() {
+        let kind1 = blag::Kind1::A;
+        let res1 = blag::f(kind1);
+        let kind6 = blah::Kind6::A;
+        let res2 = blah::f(kind6);
+        assert_eq!(res1, res2);
+        let t2 = blag::T2 {
+            l: blag::T3 { l: kind1.clone(), r: kind1.clone() },
+            r: blah::T3 { l: kind1.clone(), r: kind1.clone() }
+        };
+        let t1 = blag::T1 {
+            l: t2.clone(),
+            r: t2.clone(),
+        };
+        let t = blag::Tree {
+            l: t1.clone(),
+            r: t1.clone(),
+        };
+        let t_stream = blag::TStream { tree: t.clone(), stream: None };
+        let res1 = blag::h(&t_stream);
+        let blah::CustomResult::Ok(res2) = blah::h(&t_stream) else { unreachable!() };
+        assert_eq!(res1, res2);
+    }
+}
diff --git a/tests/runtime/rust/equal-types/test.wit b/tests/runtime/rust/equal-types/test.wit
new file mode 100644
index 000000000..2c047d4a7
--- /dev/null
+++ b/tests/runtime/rust/equal-types/test.wit
@@ -0,0 +1,54 @@
+//@ dependencies = ["host", "proxy"]
+//@ wac = "./compose.wac"
+
+package test:equal-types;
+
+interface blag {
+  variant kind1 { a, b(u64), c }
+  variant kind2 { a, b(u64), c }
+  record kind3 { a: input-stream }
+  record kind4 { a: input-stream }
+  record tree { l: t1, r: t1 }
+  record t1 { l: t2, r: t2 }
+  record t2 { l: t3, r: t3 }
+  record t3 { l: kind1, r: kind2 }
+  record t-stream { tree: tree, %stream: option> }
+  resource input-stream {
+    read: func(len: u64) -> list;
+  }
+  f: func(x: kind1) -> kind2;
+  g: func(x: kind3) -> kind4;
+  h: func(x: t-stream) -> tree;
+}
+
+interface blah {
+  use blag.{input-stream, kind4, t-stream};
+  variant kind5 { a, b(u64), c }
+  variant kind6 { a, c, b(u64) }
+  record kind7 { a: borrow<input-stream> }
+  record tt { l: t2, r: t2 }
+  record t1 { l: t3, r: t3 }
+  record t2 { l: t1, r: t1 }
+  record t3 { l: kind5, r: kind5 }
+  variant custom-result { ok(tt), err }
+  f: func(x: kind6) -> kind5;
+  g: func(x: kind7) -> kind4;
+  h: func(x: t-stream) -> custom-result;
+}
+
+world host {
+  export blah;
+  export blag;
+}
+world proxy {
+  import blag;
+  export blag;
+  import blah;
+  export blah;
+}
+world runner {
+  import blag;
+  import blah;
+  export run: func();
+}
+