diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 701d641..06e5300 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -75,6 +75,10 @@ jobs: run: | sudo apt-get install -y libcurl4-openssl-dev libssl-dev Rscript -e 'install.packages(c("devtools", "rextendr"), repos="https://cloud.r-project.org")' + + - name: Add Rust Windows target + if: runner.os == 'Windows' + run: rustup target add x86_64-pc-windows-gnu # Build all bindings (Rust, Python, WASM, R) - name: Build all components diff --git a/python_bindings/src/lib.rs b/python_bindings/src/lib.rs index 2b39a67..9abec2b 100644 --- a/python_bindings/src/lib.rs +++ b/python_bindings/src/lib.rs @@ -306,15 +306,15 @@ impl PyRustDAG { .map_err(PyValueError::new_err) } - #[pyo3(signature = (start, end, include_latents=false))] + #[pyo3(signature = (starts, ends, include_latents=false))] pub fn minimal_dseparator( &self, - start: String, - end: String, + starts: Vec, + ends: Vec, include_latents: bool, ) -> PyResult>> { self.inner - .minimal_dseparator(&start, &end, include_latents) + .minimal_dseparator(starts.clone(), ends.clone(), include_latents) .map_err(PyValueError::new_err) } diff --git a/python_bindings/tests/test_dag.py b/python_bindings/tests/test_dag.py index 6312abd..9ed0ab9 100644 --- a/python_bindings/tests/test_dag.py +++ b/python_bindings/tests/test_dag.py @@ -21,45 +21,45 @@ def test_parents_children(self, dag): assert dag.get_children("X") == ["Y"] def test_minimal_dseparator(self): - # Test case: A → B ← C + # Test case: A → B → C dag1 = DAG() dag1.add_edges_from([("A", "B"), ("B", "C")]) - assert dag1.minimal_dseparator("A", "C") == {"B"} + assert dag1.minimal_dseparator(["A"], ["C"]) == {"B"} - # Test case: A → B ← C, B → D, A → E, E → D + # Test case: A → B → C, C → D, A → E, E → D dag2 = DAG() dag2.add_edges_from([("A", "B"), ("B", "C"), ("C", "D"), ("A", "E"), ("E", "D")]) - assert dag2.minimal_dseparator("A", "D") == {"C", "E"} + assert dag2.minimal_dseparator(["A"], ["D"]) == {"C", "E"} # Test case: B → A, B → C, A → D, D → C, A → E, C → E dag3 = DAG() dag3.add_edges_from([("B", "A"), ("B", "C"), ("A", "D"), ("D", "C"), ("A", "E"), ("C", "E")]) - assert dag3.minimal_dseparator("A", "C") == {"B", "D"} + assert dag3.minimal_dseparator(["A"], ["C"]) == {"B", "D"} # Test with latents dag_lat1 = DAG() dag_lat1.add_nodes_from(["A", "B", "C"], latent=[False, True, False]) dag_lat1.add_edges_from([("A", "B"), ("B", "C")]) - assert dag_lat1.minimal_dseparator("A", "C") is None - # assert dag_lat1.minimal_dseparator("A", "C", include_latents=True) == {"B"} + assert dag_lat1.minimal_dseparator(["A"], ["C"]) is None + # assert dag_lat1.minimal_dseparator(["A"], ["C"], include_latents=True) == {"B"} dag_lat2 = DAG() dag_lat2.add_nodes_from(["A", "B", "C", "D"], latent=[False, True, False, False]) dag_lat2.add_edges_from([("A", "D"), ("D", "B"), ("B", "C")]) - assert dag_lat2.minimal_dseparator("A", "C") == {"D"} + assert dag_lat2.minimal_dseparator(["A"], ["C"]) == {"D"} dag_lat3 = DAG() dag_lat3.add_nodes_from(["A", "B", "C", "D"], latent=[False, True, False, False]) dag_lat3.add_edges_from([("A", "B"), ("B", "D"), ("D", "C")]) - assert dag_lat3.minimal_dseparator("A", "C") == {"D"} + assert dag_lat3.minimal_dseparator(["A"], ["C"]) == {"D"} dag_lat4 = DAG() dag_lat4.add_nodes_from(["A", "B", "C", "D"], latent=[False, False, False, True]) dag_lat4.add_edges_from([("A", "B"), ("B", "C"), ("A", "D"), ("D", "C")]) - assert dag_lat4.minimal_dseparator("A", "C") is None + assert dag_lat4.minimal_dseparator(["A"], ["C"]) is None # Test adjacent nodes (should raise error) dag5 = DAG() dag5.add_edges_from([("A", "B")]) - with pytest.raises(ValueError, match="No possible separators because start and end are adjacent"): - dag5.minimal_dseparator("A", "B") + with pytest.raises(ValueError, match="No possible separators because A and B are adjacent"): + dag5.minimal_dseparator(["A"], ["B"]) \ No newline at end of file diff --git a/r_bindings/causalgraphs/R/extendr-wrappers.R b/r_bindings/causalgraphs/R/extendr-wrappers.R index 8431564..f0722ca 100644 --- a/r_bindings/causalgraphs/R/extendr-wrappers.R +++ b/r_bindings/causalgraphs/R/extendr-wrappers.R @@ -46,7 +46,7 @@ RDAG$are_neighbors <- function(start, end) .Call(wrap__RDAG__are_neighbors, self RDAG$get_ancestral_graph <- function(nodes) .Call(wrap__RDAG__get_ancestral_graph, self, nodes) -RDAG$minimal_dseparator <- function(start, end, include_latents) .Call(wrap__RDAG__minimal_dseparator, self, start, end, include_latents) +RDAG$minimal_dseparator <- function(starts, ends, include_latents) .Call(wrap__RDAG__minimal_dseparator, self, starts, ends, include_latents) #' @export `$.RDAG` <- function (self, name) { func <- RDAG[[name]]; environment(func) <- environment(); func } diff --git a/r_bindings/causalgraphs/src/rust/Cargo.toml b/r_bindings/causalgraphs/src/rust/Cargo.toml index 14ec67e..860d497 100644 --- a/r_bindings/causalgraphs/src/rust/Cargo.toml +++ b/r_bindings/causalgraphs/src/rust/Cargo.toml @@ -11,9 +11,9 @@ name = 'rcausalgraphs' [dependencies] -rust_core = { git = "https://github.com/pgmpy/causalgraphs.git", branch = "main", package = "rust_core" } +# rust_core = { git = "https://github.com/pgmpy/causalgraphs.git", branch = "main", package = "rust_core" } # For local development, comment out the Git line above and uncomment this: -# rust_core = { path = "../../../../rust_core" } +rust_core = { path = "../../../../rust_core" } extendr-api = '*' diff --git a/r_bindings/causalgraphs/src/rust/src/lib.rs b/r_bindings/causalgraphs/src/rust/src/lib.rs index bc092d7..f36fb57 100644 --- a/r_bindings/causalgraphs/src/rust/src/lib.rs +++ b/r_bindings/causalgraphs/src/rust/src/lib.rs @@ -213,8 +213,10 @@ impl RDAG { /// @param end Ending node /// @param include_latents Whether to include latents (default: FALSE) /// @export - fn minimal_dseparator(&self, start: String, end: String, include_latents: Option) -> extendr_api::Result> { - let result = self.inner.minimal_dseparator(&start, &end, include_latents.unwrap_or(false)) + fn minimal_dseparator(&self, starts: Strings, ends: Strings, include_latents: Option) -> extendr_api::Result> { + let starts_vec: Vec = starts.iter().map(|s| s.to_string()).collect(); + let ends_vec: Vec = ends.iter().map(|s| s.to_string()).collect(); + let result = self.inner.minimal_dseparator(starts_vec.clone(), ends_vec.clone(), include_latents.unwrap_or(false)) .map_err(|e| Error::Other(e.to_string()))?; match result { Some(set) => { diff --git a/rust_core/src/dag.rs b/rust_core/src/dag.rs index 463c672..914eaf8 100644 --- a/rust_core/src/dag.rs +++ b/rust_core/src/dag.rs @@ -465,30 +465,44 @@ impl RustDAG { /// Tian, Paz, Pearl (1998), *Finding Minimal d-Separators*. pub fn minimal_dseparator( &self, - start: &str, - end: &str, + starts: Vec, + ends: Vec, include_latents: bool, ) -> Result>, String> { - // Example: For DAG A→B←C, B→D, trying to separate A and C - // Adjacent nodes can't be separated by any conditioning set - if self.has_edge(start, end) || self.has_edge(end, start) { - return Err("No possible separators because start and end are adjacent".to_string()); + // Validate inputs + if starts.is_empty() || ends.is_empty() { + return Ok(Some(HashSet::new())); } - // Create ancestral graph containing only ancestors of start and end - // Example: For separating A and D in A→B←C, B→D, ancestral graph = {A, B, C, D} - let ancestral_graph = self.get_ancestral_graph(vec![start.to_string(), end.to_string()])?; + // Check for adjacent pairs - if any start-end pair is adjacent, no separator exists + for start in &starts { + for end in &ends { + if self.has_edge(start, end) || self.has_edge(end, start) { + return Err(format!( + "No possible separators because {} and {} are adjacent", + start, end + )); + } + } + } - // Initial separator: all parents of both nodes (theoretical upper bound) - // Example: parents(A)={} ∪ parents(D)={B} → separator = {B} - let mut separator: HashSet = self - .get_parents(start)? - .into_iter() - .chain(self.get_parents(end)?.into_iter()) - .collect(); + + // Create ancestral graph containing only ancestors of all starts and ends + let mut all_nodes = starts.clone(); + all_nodes.extend(ends.clone()); + let ancestral_graph = self.get_ancestral_graph(all_nodes)?; + + // Initial separator: all parents of all start and end nodes + let mut separator: HashSet = HashSet::new(); + + for start in &starts { + separator.extend(self.get_parents(start)?); + } + for end in &ends { + separator.extend(self.get_parents(end)?); + } // Replace latent variables with their observable parents - // Example: If B were latent with parent L, replace B with L in separator if !include_latents { let mut changed = true; while changed { @@ -507,21 +521,32 @@ impl RustDAG { } } - separator.remove(start); - separator.remove(end); + // Remove starts and ends from separator (can't separate a node from itself) + for start in &starts { + separator.remove(start); + } + for end in &ends { + separator.remove(end); + } + + // Helper function to check if all start-end pairs are d-separated + let check_all_separated = |sep: &[String]| -> Result { + for start in &starts { + for end in &ends { + if ancestral_graph.is_dconnected(start, end, Some(sep.to_vec()), include_latents)? { + return Ok(false); // Found a connected pair + } + } + } + Ok(true) // All pairs are separated + }; // Sanity check: if our "guaranteed" separator doesn't work, no separator exists - if ancestral_graph.is_dconnected( - start, - end, - Some(separator.iter().cloned().collect()), - include_latents, - )? { + if !check_all_separated(&separator.iter().cloned().collect::>())? { return Ok(None); } // Greedy minimization: remove each node if separation still holds without it - // Example: If separator = {B, C} but {B} alone separates A from D, remove C let mut minimal_separator = separator.clone(); for u in separator { let test_separator: Vec = minimal_separator @@ -530,8 +555,8 @@ impl RustDAG { .filter(|x| x != &u) .collect(); - // If still d-separated WITHOUT this node, we can remove it - if !ancestral_graph.is_dconnected(start, end, Some(test_separator), include_latents)? { + // If all pairs are still d-separated WITHOUT this node, we can remove it + if check_all_separated(&test_separator)? { minimal_separator.remove(&u); } } @@ -694,6 +719,96 @@ impl Graph for RustDAG { self.get_ancestors_of(nodes) .map_err(|e| GraphError::NodeNotFound(e)) } + + fn is_dconnected( + &self, + start: &str, + end: &str, + observed: Option>, + include_latents: bool, + ) -> Result { + self.is_dconnected(start, end, observed, include_latents) + .map_err(|e| GraphError::NodeNotFound(e)) + } + + fn minimal_dseparator( + &self, + start: Vec, + end: Vec, + include_latents: bool, + ) -> Result>, GraphError> { + self.minimal_dseparator(start, end, include_latents) + .map_err(|e: String| GraphError::NodeNotFound(e)) + } + + + + fn all_simple_edge_paths( + &self, + source: &str, + target: &str, + ) -> Result>, GraphError> { + let source_idx = self + .node_map + .get(source) + .ok_or_else(|| GraphError::NodeNotFound(source.to_string()))?; + let target_idx = self + .node_map + .get(target) + .ok_or_else(|| GraphError::NodeNotFound(target.to_string()))?; + + let mut paths: Vec> = Vec::new(); + let mut current_path: Vec<(String, String)> = Vec::new(); + let mut visited: HashSet = HashSet::new(); + + fn dfs( + graph: &RustDAG, + current: NodeIndex, + target: NodeIndex, + visited: &mut HashSet, + current_path: &mut Vec<(String, String)>, + paths: &mut Vec>, + ) { + if current == target { + paths.push(current_path.clone()); + return; + } + + for neighbor in graph.graph.neighbors_directed(current, Direction::Outgoing) { + if !visited.contains(&neighbor) { + let source_name = graph.reverse_node_map[¤t].clone(); + let target_name = graph.reverse_node_map[&neighbor].clone(); + current_path.push((source_name, target_name)); + visited.insert(neighbor); + dfs(graph, neighbor, target, visited, current_path, paths); + visited.remove(&neighbor); + current_path.pop(); + } + } + } + + visited.insert(*source_idx); + dfs(self, *source_idx, *target_idx, &mut visited, &mut current_path, &mut paths); + Ok(paths) + } + + fn remove_edges_from(&self, edges: Vec<(String, String)>) -> Result { + let mut new_graph = self.clone(); + for (u, v) in edges { + let u_idx = new_graph + .node_map + .get(&u) + .ok_or_else(|| GraphError::NodeNotFound(u.clone()))?; + let v_idx = new_graph + .node_map + .get(&v) + .ok_or_else(|| GraphError::NodeNotFound(v.clone()))?; + if let Some(edge_idx) = new_graph.graph.find_edge(*u_idx, *v_idx) { + new_graph.graph.remove_edge(edge_idx); + } + } + Ok(new_graph) + } } impl GraphRoles for RustDAG { diff --git a/rust_core/src/graph.rs b/rust_core/src/graph.rs index 488df17..7e50958 100644 --- a/rust_core/src/graph.rs +++ b/rust_core/src/graph.rs @@ -2,7 +2,7 @@ use crate::graph_role::GraphError; use std::collections::HashSet; /// Trait for core graph operations required by causal graphs. -pub trait Graph { +pub trait Graph: Clone { /// Get all nodes in the graph. fn nodes(&self) -> Vec; @@ -11,4 +11,30 @@ pub trait Graph { /// Get the ancestors of a set of nodes (including the nodes themselves). fn ancestors(&self, nodes: Vec) -> Result, GraphError>; + + /// Check if two nodes are d-connected given an optional set of observed nodes. + fn is_dconnected( + &self, + start: &str, + end: &str, + observed: Option>, + include_latents: bool, + ) -> Result; + + fn minimal_dseparator( + &self, + start: Vec, + end: Vec, + include_latents: bool + ) -> Result>, GraphError>; + + /// Get all simple directed edge paths from source to target. + fn all_simple_edge_paths( + &self, + source: &str, + target: &str, + ) -> Result>, GraphError>; + + /// Remove a list of edges from the graph, returning a new graph. + fn remove_edges_from(&self, edges: Vec<(String, String)>) -> Result; } \ No newline at end of file diff --git a/rust_core/src/graph_role.rs b/rust_core/src/graph_role.rs index 122cb70..331220d 100644 --- a/rust_core/src/graph_role.rs +++ b/rust_core/src/graph_role.rs @@ -60,66 +60,48 @@ pub trait GraphRoles: Clone { .unwrap_or(false) } - /// Assign role to variables. Modifies in place if `inplace=true`, otherwise returns a new graph. - fn with_role(&mut self, role: String, variables: Vec, inplace: bool) -> Result { - if inplace { - // Modify self directly - for var in &variables { - if !self.has_node(var) { - return Err(GraphError::NodeNotFound(var.clone())); - } - } - let roles_map = self.get_roles_map_mut(); - let entry = roles_map.entry(role).or_insert(HashSet::new()); - for var in variables { - entry.insert(var); - } - Ok(self.clone()) // Return self.clone() for consistency, but self is modified - } else { - // Create and modify a new graph - let mut new_graph = self.clone(); - for var in &variables { - if !new_graph.has_node(var) { - return Err(GraphError::NodeNotFound(var.clone())); - } - } - let roles_map = new_graph.get_roles_map_mut(); - let entry = roles_map.entry(role).or_insert(HashSet::new()); - for var in variables { - entry.insert(var); + /// Assign role to variables in-place, modifying the graph. + fn with_role(&mut self, role: String, variables: Vec) -> Result<(), GraphError> { + for var in &variables { + if !self.has_node(var) { + return Err(GraphError::NodeNotFound(var.clone())); } - Ok(new_graph) } + let roles_map = self.get_roles_map_mut(); + let entry = roles_map.entry(role).or_insert(HashSet::new()); + for var in variables { + entry.insert(var); + } + Ok(()) } - /// Remove role from variables (or all if None). Modifies in place if `inplace=true`, otherwise returns a new graph. - fn without_role(&mut self, role: &str, variables: Option>, inplace: bool) -> Self { - if inplace { - if let Some(set) = self.get_roles_map_mut().get_mut(role) { - if let Some(vars) = variables { - for var in vars { - set.remove(&var); - } - } else { - set.clear(); - } - } - self.clone() // Return self.clone() for consistency - } else { - let mut new_graph = self.clone(); - if let Some(set) = new_graph.get_roles_map_mut().get_mut(role) { - if let Some(vars) = variables { - for var in vars { - set.remove(&var); - } - } else { - set.clear(); + /// Assign role to variables, returning a new graph without modifying the original. + fn with_role_copy(&self, role: String, variables: Vec) -> Result { + let mut new_graph = self.clone(); + new_graph.with_role(role, variables)?; + Ok(new_graph) + } + + /// Remove role from variables (or all if None) in-place. + fn without_role(&mut self, role: &str, variables: Option>) -> () { + if let Some(set) = self.get_roles_map_mut().get_mut(role) { + if let Some(vars) = variables { + for var in vars { + set.remove(&var); } + } else { + set.clear(); } - new_graph } } + /// Remove role from variables (or all if None), returning a new graph. + fn without_role_copy(&self, role: &str, variables: Option>) -> Self { + let mut new_graph = self.clone(); + new_graph.without_role(role, variables); + new_graph + } + /// Validate causal structure (has exposure and outcome). fn is_valid_causal_structure(&self) -> Result { let has_exposure = self.has_role("exposure"); diff --git a/rust_core/src/identification/frontdoor.rs b/rust_core/src/identification/frontdoor.rs new file mode 100644 index 0000000..088e478 --- /dev/null +++ b/rust_core/src/identification/frontdoor.rs @@ -0,0 +1,209 @@ +use crate::identification::base::BaseIdentification; +use crate::dag::RustDAG; +use crate::graph::Graph; +use crate::graph_role::{GraphError, GraphRoles}; +use std::collections::{HashMap, HashSet}; +use itertools::Itertools; // For powerset + +/// Adjustment class to validate backdoor adjustment sets. +pub struct Adjustment { + variant: String, +} + +impl Adjustment { + pub fn new(variant: &str) -> Self { + Adjustment { + variant: variant.to_string(), + } + } + + /// Validate if the adjustment set blocks all backdoor paths from exposure to outcome. + pub fn validate( + &self, + causal_graph: &T, + ) -> Result { + let exposure = causal_graph.get_role("exposure"); + let outcome = causal_graph.get_role("outcome"); + let adjustment = causal_graph.get_role("adjustment"); + + if exposure.is_empty() || outcome.is_empty() { + return Err(GraphError::InvalidOperation( + "Exposure and outcome roles must be defined".to_string(), + )); + } + + if exposure.len() > 1 || outcome.len() > 1 { + return Err(GraphError::InvalidOperation( + "Adjustment validation supports only single exposure and outcome".to_string(), + )); + } + + let exposure_str = exposure.first().unwrap(); + let outcome_str = outcome.first().unwrap(); + + // Remove all outgoing edges from exposure to check only backdoor paths + let edges_to_remove: Vec<(String, String)> = causal_graph + .nodes() + .into_iter() + .filter_map(|node| { + if causal_graph.parents(&node).ok()?.contains(&exposure_str.to_string()) { + Some((exposure_str.clone(), node)) + } else { + None + } + }) + .collect(); + + let graph_without_forward_edges = causal_graph.remove_edges_from(edges_to_remove)?; + + // Check if there's any unblocked backdoor path + // include_latents=true is critical - we need to check paths through latent confounders + let has_unblocked_backdoor = graph_without_forward_edges.is_dconnected( + exposure_str, + outcome_str, + Some(adjustment.clone()), + true // MUST be true to include latent variables + )?; + + // Valid if no unblocked backdoor paths exist + Ok(!has_unblocked_backdoor) + } +} + +/// Frontdoor identification for causal graphs. +pub struct Frontdoor { + variant: Option, // None or "all" +} + +impl Frontdoor { + /// Create a new Frontdoor instance. + pub fn new(variant: Option) -> Self { + Frontdoor { variant } + } + + /// Validate a frontdoor set in a causal graph. + pub fn validate( + &self, + causal_graph: &T, + ) -> Result { + let exposure = causal_graph.get_role("exposure"); + let outcome = causal_graph.get_role("outcome"); + let frontdoor = causal_graph.get_role("frontdoor"); + + if exposure.is_empty() || outcome.is_empty() { + return Err(GraphError::InvalidOperation( + "Exposure and outcome roles must be defined".to_string(), + )); + } + + if exposure.len() > 1 || outcome.len() > 1 { + return Err(GraphError::InvalidOperation( + "Frontdoor identification supports only single exposure and outcome".to_string(), + )); + } + + let exposure = exposure.first().unwrap(); + let outcome = outcome.first().unwrap(); + + println!("Validating frontdoor: exposure={}, outcome={}, frontdoor={:?}", exposure, outcome, frontdoor); + + // 0. Check for directed paths from X to Y + let directed_paths = causal_graph.all_simple_edge_paths(exposure, outcome)?; + println!("Step 0: directed_paths count = {}", directed_paths.len()); + if directed_paths.is_empty() { + return Ok(false); + } + + // 1. Z intercepts all directed paths from X to Y + let unblocked_paths: Vec<_> = directed_paths + .into_iter() + .filter(|path| !path.iter().any(|(_, v)| frontdoor.contains(v))) + .collect(); + println!("Step 1: unblocked_paths count = {}", unblocked_paths.len()); + if !unblocked_paths.is_empty() { + return Ok(false); + } + + // 2. No backdoor path from X to Z + let adjustment = Adjustment::new("minimal"); + // In Frontdoor::validate, step 2: + for z in &frontdoor { + let mut graph_copy = causal_graph.clone(); + graph_copy = graph_copy.without_role_copy("exposure", None); + graph_copy = graph_copy.without_role_copy("outcome", None); + graph_copy = graph_copy.without_role_copy("adjustment", None); + + graph_copy = graph_copy.with_role_copy("exposure".to_string(), vec![exposure.clone()])?; + graph_copy = graph_copy.with_role_copy("outcome".to_string(), vec![z.clone()])?; + graph_copy = graph_copy.with_role_copy("adjustment".to_string(), vec![])?; + + let is_valid = adjustment.validate(&graph_copy)?; + if !is_valid { + return Ok(false); + } + } + + // 3. All backdoor paths from Z to Y are blocked by X + for z in &frontdoor { + let mut graph_copy = causal_graph.clone(); + graph_copy = graph_copy.without_role_copy("exposure", None); + graph_copy = graph_copy.without_role_copy("outcome", None); + graph_copy = graph_copy.without_role_copy("adjustment", None); + graph_copy = graph_copy.with_role_copy("exposure".to_string(), vec![z.clone()])?; + graph_copy = graph_copy.with_role_copy("outcome".to_string(), vec![outcome.clone()])?; + graph_copy = graph_copy.with_role_copy("adjustment".to_string(), vec![exposure.clone()])?; + + let is_valid = adjustment.validate(&graph_copy)?; + if !is_valid { + return Ok(false); + } + } + + Ok(true) + } +} + +impl BaseIdentification for Frontdoor { + fn _identify( + &self, + causal_graph: &T, + ) -> Result<(T, bool), GraphError> { + let exposure = causal_graph.get_role("exposure"); + let outcome = causal_graph.get_role("outcome"); + + if exposure.is_empty() || outcome.is_empty() { + return Err(GraphError::InvalidOperation( + "Exposure and outcome roles must be defined".to_string(), + )); + } + + // Get possible frontdoor variables: observed nodes excluding exposure and outcome + let possible_frontdoor: HashSet = causal_graph + .nodes() + .into_iter() + .filter(|n| !causal_graph.get_role("exposure").contains(n)) + .filter(|n| !causal_graph.get_role("outcome").contains(n)) + .filter(|n| !causal_graph.get_role("latents").contains(n)) + .collect(); + + // Generate powerset of possible frontdoor variables + let mut valid_frontdoor_graphs = Vec::new(); + for s in possible_frontdoor.into_iter().powerset() { + let s_vec: Vec = s.into_iter().collect(); + let updated_graph = causal_graph.with_role_copy("frontdoor".to_string(), s_vec.clone())?; + if self.validate(&updated_graph)? { + if self.variant.is_none() { + return Ok((updated_graph, true)); + } else if self.variant.as_deref() == Some("all") { + valid_frontdoor_graphs.push(updated_graph); + } + } + } + + if valid_frontdoor_graphs.is_empty() { + Ok((causal_graph.clone(), false)) + } else { + Ok((valid_frontdoor_graphs[0].clone(), true)) + } + } +} \ No newline at end of file diff --git a/rust_core/src/identification/mod.rs b/rust_core/src/identification/mod.rs index 1d63ce9..c7cce86 100644 --- a/rust_core/src/identification/mod.rs +++ b/rust_core/src/identification/mod.rs @@ -1,3 +1,5 @@ pub mod base; +pub use base::BaseIdentification; -pub use base::BaseIdentification; \ No newline at end of file +pub mod frontdoor; +pub use frontdoor::Frontdoor; diff --git a/rust_core/tests/base_tests.rs b/rust_core/tests/base_tests.rs index 98ded7e..ae81aea 100644 --- a/rust_core/tests/base_tests.rs +++ b/rust_core/tests/base_tests.rs @@ -1,8 +1,5 @@ use std::collections::{HashMap, HashSet}; -use rust_core::{graph::Graph, graph_role::{GraphError, GraphRoles}, identification::base, RustDAG}; - -use base::BaseIdentification; - +use rust_core::{graph::Graph, graph_role::{GraphError, GraphRoles}, identification::base::BaseIdentification, RustDAG}; /// A simple identification method that assigns the "adjustment" role to either /// the first or last non-exposure, non-outcome node (alphabetically sorted), @@ -11,6 +8,7 @@ use base::BaseIdentification; struct DummyIdentification { variant: Option, } + impl DummyIdentification { fn new(variant: Option<&str>) -> Self { DummyIdentification { @@ -18,31 +16,31 @@ impl DummyIdentification { } } } + impl BaseIdentification for DummyIdentification { - fn _identify(&self, causal_graph: &T) -> Result<(T, bool), GraphError> { - let mut mutable_graph = causal_graph.clone(); + fn _identify(&self, causal_graph: &T) -> Result<(T, bool), GraphError> { + let non_role_nodes: HashSet = causal_graph + .nodes() + .into_iter() + .collect::>() + .difference( + &causal_graph + .get_role("exposure") + .into_iter() + .chain(causal_graph.get_role("outcome").into_iter()) + .collect::>(), + ) + .cloned() + .collect(); + let mut sorted_nodes: Vec = non_role_nodes.into_iter().collect(); + sorted_nodes.sort(); + match self.variant.as_deref() { Some("first") => { - let non_role_nodes: HashSet = causal_graph - .nodes() - .into_iter() - .collect::>() - .difference( - &causal_graph - .get_role("exposure") - .into_iter() - .chain(causal_graph.get_role("outcome").into_iter()) - .collect::>(), - ) - .cloned() - .collect(); - let mut sorted_nodes: Vec = non_role_nodes.into_iter().collect(); - sorted_nodes.sort(); if let Some(adjustment_node) = sorted_nodes.first() { - let identified_cg = mutable_graph.with_role( + let identified_cg = causal_graph.with_role_copy( "adjustment".to_string(), vec![adjustment_node.clone()], - false, )?; Ok((identified_cg, true)) } else { @@ -50,26 +48,10 @@ impl BaseIdentification for DummyIdentification { } } Some("last") => { - let non_role_nodes: HashSet = causal_graph - .nodes() - .into_iter() - .collect::>() - .difference( - &causal_graph - .get_role("exposure") - .into_iter() - .chain(causal_graph.get_role("outcome").into_iter()) - .collect::>(), - ) - .cloned() - .collect(); - let mut sorted_nodes: Vec = non_role_nodes.into_iter().collect(); - sorted_nodes.sort(); if let Some(adjustment_node) = sorted_nodes.last() { - let identified_cg = mutable_graph.with_role( + let identified_cg = causal_graph.with_role_copy( "adjustment".to_string(), vec![adjustment_node.clone()], - false, )?; Ok((identified_cg, true)) } else { @@ -81,7 +63,6 @@ impl BaseIdentification for DummyIdentification { } } - #[test] fn test_base_identification_first() { let mut cg = RustDAG::new(); @@ -96,8 +77,8 @@ fn test_base_identification_first() { ) .unwrap(); - cg.with_role("exposure".to_string(), vec!["X".to_string()], true).unwrap(); - cg.with_role("outcome".to_string(), vec!["Y".to_string()], true).unwrap(); + cg.with_role("exposure".to_string(), vec!["X".to_string()]).unwrap(); + cg.with_role("outcome".to_string(), vec!["Y".to_string()]).unwrap(); let identifier = DummyIdentification::new(Some("first")); let (identified_cg, is_identified) = identifier.identify(&cg).unwrap(); assert!(is_identified); @@ -111,7 +92,6 @@ fn test_base_identification_first() { assert_eq!(identified_cg.get_role_dict(), expected_roles); } - #[test] fn test_base_identification_last() { let mut cg = RustDAG::new(); @@ -125,8 +105,8 @@ fn test_base_identification_last() { None, ) .unwrap(); - cg.with_role("exposure".to_string(), vec!["X".to_string()], true).unwrap(); - cg.with_role("outcome".to_string(), vec!["Y".to_string()], true).unwrap(); + cg.with_role("exposure".to_string(), vec!["X".to_string()]).unwrap(); + cg.with_role("outcome".to_string(), vec!["Y".to_string()]).unwrap(); let identifier = DummyIdentification::new(Some("last")); let (identified_cg, is_identified) = identifier.identify(&cg).unwrap(); assert!(is_identified); @@ -153,8 +133,8 @@ fn test_base_identification_gibberish() { None, ) .unwrap(); - cg.with_role("exposure".to_string(), vec!["X".to_string()], true).unwrap(); - cg.with_role("outcome".to_string(), vec!["Y".to_string()], true).unwrap(); + cg.with_role("exposure".to_string(), vec!["X".to_string()]).unwrap(); + cg.with_role("outcome".to_string(), vec!["Y".to_string()]).unwrap(); let identifier = DummyIdentification::new(Some("gibberish")); let (identified_cg, is_identified) = identifier.identify(&cg).unwrap(); assert!(!is_identified); diff --git a/rust_core/tests/frontdoor_tests.rs b/rust_core/tests/frontdoor_tests.rs new file mode 100644 index 0000000..079f0f0 --- /dev/null +++ b/rust_core/tests/frontdoor_tests.rs @@ -0,0 +1,115 @@ +use rust_core::identification::base::BaseIdentification; +use rust_core::dag::RustDAG; +use rust_core::graph::Graph; +use rust_core::graph_role::{GraphError, GraphRoles}; +use std::collections::{HashMap, HashSet}; +use itertools::Itertools; // For powerset + + +#[cfg(test)] +mod tests { + use rust_core::identification::Frontdoor; + + use super::*; + + fn create_frontdoor_model() -> RustDAG { + let mut dag = RustDAG::new(); + dag.add_nodes_from(vec!["X".to_string(), "M".to_string(), "Y".to_string()], None) + .unwrap(); + dag.add_edges_from(vec![("X".to_string(), "M".to_string()), ("M".to_string(), "Y".to_string())], None) + .unwrap(); + dag.with_role("exposure".to_string(), vec!["X".to_string()]) + .unwrap(); + dag.with_role("outcome".to_string(), vec!["Y".to_string()]) + .unwrap(); + dag + } + + fn create_frontdoor_model_latent() -> RustDAG { + let mut dag = RustDAG::new(); + dag.add_nodes_from( + vec!["X".to_string(), "M".to_string(), "Y".to_string(), "U".to_string()], + Some(vec![false, false, false, true]), + ) + .unwrap(); + dag.add_edges_from( + vec![ + ("X".to_string(), "M".to_string()), + ("M".to_string(), "Y".to_string()), + ("U".to_string(), "X".to_string()), + ("U".to_string(), "Y".to_string()), + ], + None, + ) + .unwrap(); + dag.with_role("exposure".to_string(), vec!["X".to_string()]) + .unwrap(); + dag.with_role("outcome".to_string(), vec!["Y".to_string()]) + .unwrap(); + dag + } + + fn create_frontdoor_model_noniden() -> RustDAG { + let mut dag = RustDAG::new(); + dag.add_nodes_from( + vec!["X".to_string(), "M".to_string(), "Y".to_string(), "U".to_string()], + Some(vec![false, false, false, true]), + ) + .unwrap(); + dag.add_edges_from( + vec![ + ("X".to_string(), "M".to_string()), + ("M".to_string(), "Y".to_string()), + ("U".to_string(), "X".to_string()), + ("U".to_string(), "Y".to_string()), + ("U".to_string(), "M".to_string()), + ], + None, + ) + .unwrap(); + dag.with_role("exposure".to_string(), vec!["X".to_string()]) + .unwrap(); + dag.with_role("outcome".to_string(), vec!["Y".to_string()]) + .unwrap(); + dag + } + + #[test] + fn test_frontdoor() { + let dag = create_frontdoor_model(); + let frontdoor = Frontdoor::new(None); + let (identified_dag, is_identified) = frontdoor.identify(&dag).unwrap(); + + assert!(is_identified); + assert_eq!(identified_dag.get_role("exposure"), vec!["X"]); + assert_eq!(identified_dag.get_role("outcome"), vec!["Y"]); + assert_eq!(identified_dag.get_role("frontdoor"), vec!["M"]); + assert_eq!(identified_dag.latents, HashSet::new()); + } + + #[test] + fn test_frontdoor_latent() { + let dag = create_frontdoor_model_latent(); + let frontdoor = Frontdoor::new(None); + let (identified_dag, is_identified) = frontdoor.identify(&dag).unwrap(); + + assert!(is_identified); + assert_eq!(identified_dag.get_role("exposure"), vec!["X"]); + assert_eq!(identified_dag.get_role("outcome"), vec!["Y"]); + assert_eq!(identified_dag.get_role("frontdoor"), vec!["M"]); + assert_eq!(identified_dag.latents, HashSet::from_iter(vec!["U".to_string()])); + } + + #[test] + fn test_frontdoor_noniden() { + let dag = create_frontdoor_model_noniden(); + let frontdoor = Frontdoor::new(None); + let (identified_dag, is_identified) = frontdoor.identify(&dag).unwrap(); + + assert!(!is_identified); + assert_eq!(identified_dag.get_role("exposure"), vec!["X"]); + assert_eq!(identified_dag.get_role("outcome"), vec!["Y"]); + assert_eq!(identified_dag.get_role("frontdoor"), Vec::::new()); + assert_eq!(identified_dag.latents, HashSet::from_iter(vec!["U".to_string()])); + } +} \ No newline at end of file diff --git a/rust_core/tests/test_dag.rs b/rust_core/tests/test_dag.rs index 74eceb4..4e39bba 100644 --- a/rust_core/tests/test_dag.rs +++ b/rust_core/tests/test_dag.rs @@ -128,7 +128,7 @@ fn test_minimal_dseparator_simple() { None, ) .unwrap(); - let result = dag.minimal_dseparator("A", "C", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["C".to_string()], false).unwrap(); let expected: HashSet = vec!["B".to_string()].into_iter().collect(); assert_eq!(result, Some(expected)); } @@ -147,7 +147,7 @@ fn test_minimal_dseparator_complex() { None, ) .unwrap(); - let result = dag.minimal_dseparator("A", "D", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["D".to_string()], false).unwrap(); let expected: HashSet = vec!["C".to_string(), "E".to_string()].into_iter().collect(); assert_eq!(result, Some(expected)); } @@ -167,7 +167,7 @@ fn test_minimal_dseparator_latent_case_1() { ) .unwrap(); // No d-separator should exist because B is latent - let result = dag.minimal_dseparator("A", "C", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["C".to_string()], false).unwrap(); assert_eq!(result, None); } @@ -188,7 +188,7 @@ fn test_minimal_dseparator_latent_case_2() { ) .unwrap(); - let result = dag.minimal_dseparator("A", "C", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["C".to_string()], false).unwrap(); let expected: HashSet = vec!["D".to_string()].into_iter().collect(); assert_eq!( result, @@ -214,14 +214,12 @@ fn test_minimal_dseparator_latent_case_3() { None, ) .unwrap(); - let result = dag.minimal_dseparator("A", "C", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["C".to_string()], false).unwrap(); assert_eq!(result, None, "Expected no d-separator when D is latent with multiple paths A→B→C and A→D→C (D is unobservable, because of its latent status)"); } #[test] fn test_minimal_dseparator_latent_case_5() { - // dag_lat5 = DAG([("A", "B"), ("B", "C"), ("A", "D"), ("D", "E"), ("E", "C")], latents={"E"}) - // self.assertEqual(dag_lat5.minimal_dseparator(start="A", end="C"), {"B", "D"}) let mut dag = RustDAG::new(); dag.add_node("A".to_string(), false).unwrap(); dag.add_node("B".to_string(), false).unwrap(); @@ -239,12 +237,12 @@ fn test_minimal_dseparator_latent_case_5() { None, ) .unwrap(); - let result = dag.minimal_dseparator("A", "C", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["C".to_string()], false).unwrap(); let expected: HashSet = vec!["B".to_string(), "D".to_string()].into_iter().collect(); assert_eq!( result, Some(expected), - "Expected [B, D] to d-separate A and C when E is latent(Observe B & parent of C => D)" + "Expected [B, D] to d-separate A and C when E is latent (Observe B & parent of C => D)" ); } @@ -254,9 +252,9 @@ fn test_minimal_dseparator_adjacent_error() { dag.add_edges_from(vec![("A".to_string(), "B".to_string())], None) .unwrap(); - let result = dag.minimal_dseparator("A", "B", false); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["B".to_string()], false); assert!(result.is_err()); - assert!(result.unwrap_err().contains("adjacent")); + assert!(result.unwrap_err().to_string().contains("adjacent")); } #[test] @@ -281,65 +279,63 @@ fn test_role_hash_equality() { } // Helper function to calculate hash value as u64 - fn calculate_hash(t: &T) -> u64 { - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - t.hash(&mut hasher); - hasher.finish() - } +fn calculate_hash(t: &T) -> u64 { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + t.hash(&mut hasher); + hasher.finish() +} #[test] - fn test_hash() { - // Create two identical M-bias DAGs: A -> B <- C, B -> D, E -> D - let mut dag1 = RustDAG::new(); - dag1.add_edges_from( - vec![ - ("A".to_string(), "B".to_string()), - ("C".to_string(), "B".to_string()), - ("B".to_string(), "D".to_string()), - ("E".to_string(), "D".to_string()), - ], - None, - ) - .unwrap(); - - let mut dag2 = RustDAG::new(); - dag2.add_edges_from( - vec![ - ("A".to_string(), "B".to_string()), - ("C".to_string(), "B".to_string()), - ("B".to_string(), "D".to_string()), - ("E".to_string(), "D".to_string()), - ], - None, - ) - .unwrap(); - - // Test identical DAGs have same hash - assert_eq!(dag1, dag2); - assert_eq!(calculate_hash(&dag1), calculate_hash(&dag2)); +fn test_hash() { + // Create two identical M-bias DAGs: A -> B <- C, B -> D, E -> D + let mut dag1 = RustDAG::new(); + dag1.add_edges_from( + vec![ + ("A".to_string(), "B".to_string()), + ("C".to_string(), "B".to_string()), + ("B".to_string(), "D".to_string()), + ("E".to_string(), "D".to_string()), + ], + None, + ) + .unwrap(); - // Add exposure role to dag1 - dag1.with_role("exposure".to_string(), vec!["E".to_string()], true).unwrap(); - assert_ne!(dag1, dag2); - assert_ne!(calculate_hash(&dag1), calculate_hash(&dag2)); + let mut dag2 = RustDAG::new(); + dag2.add_edges_from( + vec![ + ("A".to_string(), "B".to_string()), + ("C".to_string(), "B".to_string()), + ("B".to_string(), "D".to_string()), + ("E".to_string(), "D".to_string()), + ], + None, + ) + .unwrap(); - // Add exposure role to dag2 - dag2.with_role("exposure".to_string(), vec!["E".to_string()], true).unwrap(); - assert_eq!(dag1, dag2); - assert_eq!(calculate_hash(&dag1), calculate_hash(&dag2)); + // Test identical DAGs have same hash + assert_eq!(dag1, dag2); + assert_eq!(calculate_hash(&dag1), calculate_hash(&dag2)); - // Add outcome role to dag1 - dag1.with_role("outcome".to_string(), vec!["D".to_string()], true).unwrap(); - assert_ne!(dag1, dag2); - assert_ne!(calculate_hash(&dag1), calculate_hash(&dag2)); + // Add exposure role to dag1 + dag1.with_role("exposure".to_string(), vec!["E".to_string()]).unwrap(); + assert_ne!(dag1, dag2); + assert_ne!(calculate_hash(&dag1), calculate_hash(&dag2)); - // Add outcome role to dag2 - dag2.with_role("outcome".to_string(), vec!["D".to_string()], true).unwrap(); - assert_eq!(dag1, dag2); - assert_eq!(calculate_hash(&dag1), calculate_hash(&dag2)); - } + // Add exposure role to dag2 + dag2.with_role("exposure".to_string(), vec!["E".to_string()]).unwrap(); + assert_eq!(dag1, dag2); + assert_eq!(calculate_hash(&dag1), calculate_hash(&dag2)); + // Add outcome role to dag1 + dag1.with_role("outcome".to_string(), vec!["D".to_string()]).unwrap(); + assert_ne!(dag1, dag2); + assert_ne!(calculate_hash(&dag1), calculate_hash(&dag2)); + // Add outcome role to dag2 + dag2.with_role("outcome".to_string(), vec!["D".to_string()]).unwrap(); + assert_eq!(dag1, dag2); + assert_eq!(calculate_hash(&dag1), calculate_hash(&dag2)); +} #[test] fn test_roles() { @@ -357,19 +353,19 @@ fn test_roles() { .unwrap(); // Test assigning role to existing node - let result = dag.with_role("exposure".to_string(), vec!["A".to_string()], true); + let result = dag.with_role("exposure".to_string(), vec!["A".to_string()]); assert!(result.is_ok()); assert!(dag.has_role("exposure")); assert_eq!(dag.get_role("exposure"), vec!["A".to_string()]); // Test assigning role to non-existent node (should fail) - let result = dag.with_role("exposure".to_string(), vec!["Z".to_string()], true); + let result = dag.with_role("exposure".to_string(), vec!["Z".to_string()]); assert!(matches!(result, Err(GraphError::NodeNotFound(ref s)) if s == "Z")); // Verify exposure role still contains only "A" assert_eq!(dag.get_role("exposure"), vec!["A".to_string()]); // Test assigning outcome role - let result = dag.with_role("outcome".to_string(), vec!["D".to_string()], true); + let result = dag.with_role("outcome".to_string(), vec!["D".to_string()]); assert!(result.is_ok()); assert!(dag.has_role("outcome")); assert_eq!(dag.get_role("outcome"), vec!["D".to_string()]); @@ -378,7 +374,102 @@ fn test_roles() { assert!(dag.is_valid_causal_structure().is_ok()); // Test removing role - dag.without_role("exposure", None, true); + dag.without_role("exposure", None); assert!(!dag.has_role("exposure")); assert!(dag.is_valid_causal_structure().is_err()); } + +#[test] +fn test_multiple_chains_separation() { + // Graph: A1→B1→C1, A2→B2→C2, A3→B3→C3 (three independent chains) + let mut dag = RustDAG::new(); + dag.add_edges_from( + vec![ + ("A1".to_string(), "B1".to_string()), + ("B1".to_string(), "C1".to_string()), + ("A2".to_string(), "B2".to_string()), + ("B2".to_string(), "C2".to_string()), + ("A3".to_string(), "B3".to_string()), + ("B3".to_string(), "C3".to_string()), + ], + None, + ).unwrap(); + + // Separate {A1, A2, A3} from {C1, C2, C3} + let result = dag.minimal_dseparator( + vec!["A1".to_string(), "A2".to_string(), "A3".to_string()], + vec!["C1".to_string(), "C2".to_string(), "C3".to_string()], + false + ).unwrap(); + + let expected: HashSet = vec![ + "B1".to_string(), "B2".to_string(), "B3".to_string() + ].into_iter().collect(); + assert_eq!(result, Some(expected)); +} + +#[test] +fn test_fork_convergence_pattern() { + // Graph: A→B, A→C, B→D, C→D (fork from A, convergence at D) + let mut dag = RustDAG::new(); + dag.add_edges_from( + vec![ + ("A".to_string(), "B".to_string()), + ("A".to_string(), "C".to_string()), + ("B".to_string(), "D".to_string()), + ("C".to_string(), "D".to_string()), + ], + None, + ).unwrap(); + + // Separate {B, C} from {D} - should need both B and C + let result = dag.minimal_dseparator( + vec!["B".to_string(), "C".to_string()], + vec!["D".to_string()], + false + ); + + // No separator should exist because B→D and C→D are direct edges + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("adjacent")); +} + +#[test] +fn test_diamond_pattern_multiple() { + // Graph: Multiple diamond patterns + // A1→B1, A1→C1, B1→D1, C1→D1 + // A2→B2, A2→C2, B2→D2, C2→D2 + let mut dag = RustDAG::new(); + dag.add_edges_from( + vec![ + ("A1".to_string(), "B1".to_string()), + ("A1".to_string(), "C1".to_string()), + ("B1".to_string(), "D1".to_string()), + ("C1".to_string(), "D1".to_string()), + ("A2".to_string(), "B2".to_string()), + ("A2".to_string(), "C2".to_string()), + ("B2".to_string(), "D2".to_string()), + ("C2".to_string(), "D2".to_string()), + ], + None, + ).unwrap(); + + // Separate {A1, A2} from {D1, D2} + let result = dag.minimal_dseparator( + vec!["A1".to_string(), "A2".to_string()], + vec!["D1".to_string(), "D2".to_string()], + false + ).unwrap(); + + assert!(result.is_some()); + let separator: HashSet = result.unwrap(); + + // Should include intermediate nodes from both diamonds + assert!(!separator.is_empty()); + // Should contain some combination of B1, C1, B2, C2 + let intermediate_nodes: HashSet = vec![ + "B1".to_string(), "C1".to_string(), "B2".to_string(), "C2".to_string() + ].into_iter().collect(); + + assert_eq!(separator, intermediate_nodes); +} \ No newline at end of file diff --git a/wasm_bindings/js/tests/test-dag.js b/wasm_bindings/js/tests/test-dag.js index dffee2e..fdb0d7d 100644 --- a/wasm_bindings/js/tests/test-dag.js +++ b/wasm_bindings/js/tests/test-dag.js @@ -43,11 +43,11 @@ describe("DAG wasm (CJS)", () => { expect(areNeighbors).toBe(false); }); - it("should compute minimal d-separator (simple)", () => { + it("should compute minimal d-separator (simple)", () => { const dag = new cg.DAG(); dag.addEdge("A", "B"); dag.addEdge("B", "C"); - const sep = dag.minimalDseparator("A", "C"); + const sep = dag.minimalDseparator(["A"], ["C"]); expect(sep.sort()).toEqual(["B"]); }); @@ -58,7 +58,7 @@ describe("DAG wasm (CJS)", () => { dag.addEdge("C", "D"); dag.addEdge("A", "E"); dag.addEdge("E", "D"); - const sep = dag.minimalDseparator("A", "D"); + const sep = dag.minimalDseparator(["A"], ["D"]); expect(sep.sort()).toEqual(["C", "E"]); }); @@ -69,7 +69,7 @@ describe("DAG wasm (CJS)", () => { dag.addNode("C", false); dag.addEdge("A", "B"); dag.addEdge("B", "C"); - const sep = dag.minimalDseparator("A", "C"); + const sep = dag.minimalDseparator(["A"], ["C"]); expect(sep).toBeNull(); }); diff --git a/wasm_bindings/src/lib.rs b/wasm_bindings/src/lib.rs index e1155f2..42b2b67 100644 --- a/wasm_bindings/src/lib.rs +++ b/wasm_bindings/src/lib.rs @@ -98,8 +98,8 @@ impl DAG { // In RustDAG impl #[wasm_bindgen(js_name = minimalDseparator, catch)] - pub fn minimal_dseparator(&self, start: String, end: String, include_latents: Option) -> Result { - let result = self.inner.minimal_dseparator(&start, &end, include_latents.unwrap_or(false)) + pub fn minimal_dseparator(&self, starts: Vec, ends: Vec, include_latents: Option) -> Result { + let result = self.inner.minimal_dseparator(starts, ends, include_latents.unwrap_or(false)) .map_err(|e| JsValue::from_str(&e))?; match result {