From 9e5152d72cdc2fa313b305bfe1411d658610ca9f Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Sun, 3 Aug 2025 23:54:29 +0530 Subject: [PATCH 01/13] initial pdag impl --- python_bindings/.gitignore | 1 + rust_core/src/lib.rs | 3 +- rust_core/src/pdag.rs | 239 ++++++++++++++++++++++++++++++++++ rust_core/tests/pdag_tests.rs | 116 +++++++++++++++++ 4 files changed, 358 insertions(+), 1 deletion(-) create mode 100644 rust_core/src/pdag.rs create mode 100644 rust_core/tests/pdag_tests.rs diff --git a/python_bindings/.gitignore b/python_bindings/.gitignore index c8f0442..d0d66fc 100644 --- a/python_bindings/.gitignore +++ b/python_bindings/.gitignore @@ -70,3 +70,4 @@ docs/_build/ # Pyenv .python-version +*pvt_tests* \ No newline at end of file diff --git a/rust_core/src/lib.rs b/rust_core/src/lib.rs index 77d0741..bcf918f 100644 --- a/rust_core/src/lib.rs +++ b/rust_core/src/lib.rs @@ -1,7 +1,8 @@ // Re-export modules/structs from your core logic pub mod dag; pub mod independencies; -// pub mod pdag; // Add PDAG.rs later if needed +pub mod pdag; // Add PDAG.rs later if needed pub use dag::RustDAG; +pub use pdag::RustPDAG; pub use independencies::{IndependenceAssertion, Independencies}; \ No newline at end of file diff --git a/rust_core/src/pdag.rs b/rust_core/src/pdag.rs new file mode 100644 index 0000000..10fcdd4 --- /dev/null +++ b/rust_core/src/pdag.rs @@ -0,0 +1,239 @@ +use petgraph::Direction; +use rustworkx_core::petgraph::graph::{DiGraph, NodeIndex}; +use std::collections::{HashMap, HashSet}; + +use crate::RustDAG; + + +#[derive(Debug, Clone)] +pub struct RustPDAG { + pub graph: DiGraph, + pub node_map: HashMap, + pub reverse_node_map: HashMap, + pub directed_edges: HashSet<(String, String)>, + pub undirected_edges: HashSet<(String, String)>, + pub latents: HashSet, +} +impl RustPDAG { + pub fn new() -> Self { + RustPDAG { + graph: DiGraph::new(), + node_map: HashMap::new(), + reverse_node_map: HashMap::new(), + directed_edges: HashSet::new(), + undirected_edges: HashSet::new(), + latents: HashSet::new(), + } + } + + /// Get all edges in the graph + pub fn edges(&self) -> Vec<(String, String)> { + self.graph + .edge_indices() + .map(|edge_idx| { + let (source, target) = self.graph.edge_endpoints(edge_idx).unwrap(); + ( + self.reverse_node_map[&source].clone(), + self.reverse_node_map[&target].clone(), + ) + }) + .collect() + } + + /// Get all nodes in the graph + pub fn nodes(&self) -> Vec { + let mut nodes: Vec = self.node_map.keys().cloned().collect(); + nodes.sort(); // Sort alphabetically for deterministic order + nodes + } + /// Adds a single node to the PDAG. + pub fn add_node(&mut self, node: String, latent: bool) -> Result<(), String> { + if !self.node_map.contains_key(&node) { + let idx: NodeIndex = self.graph.add_node(node.clone()); + self.node_map.insert(node.clone(), idx); + self.reverse_node_map.insert(idx, node.clone()); + + if latent { + self.latents.insert(node); + } + } + Ok(()) + } + + /// Adds multiple nodes to the PDAG. + pub fn add_nodes_from(&mut self, nodes: Vec, latent: Option>) -> Result<(), String> { + let latent_flags: Vec = latent.unwrap_or_else(|| vec![false; nodes.len()]); + + if nodes.len() != latent_flags.len() { + return Err("Length of nodes and latent flags must match".to_string()); + } + + for (node, is_latent) in nodes.iter().zip(latent_flags.iter()) { + self.add_node(node.clone(), *is_latent)?; + } + Ok(()) + } + + /// Adds a single edge (directed or undirected) to the PDAG. + pub fn add_edge(&mut self, u: String, v: String, weight: Option, directed: bool) -> Result<(), String> { + // Add nodes if they don't exist + self.add_node(u.clone(), false)?; + self.add_node(v.clone(), false)?; + + let u_idx = self.node_map[&u]; + let v_idx = self.node_map[&v]; + + if directed { + // Check for cycles before adding directed edge + let mut temp_graph = self.graph.clone(); + temp_graph.add_edge(u_idx, v_idx, weight.unwrap_or(1.0)); + if petgraph::algo::is_cyclic_directed(&temp_graph) { + return Err(format!("Adding directed edge {} -> {} creates a cycle", u, v)); + } + self.graph.add_edge(u_idx, v_idx, weight.unwrap_or(1.0)); + self.directed_edges.insert((u.clone(), v.clone())); + } else { + // Add undirected edge (bidirectional in graph) + self.graph.add_edge(u_idx, v_idx, weight.unwrap_or(1.0)); + self.graph.add_edge(v_idx, u_idx, weight.unwrap_or(1.0)); + self.undirected_edges.insert((u.clone(), v.clone())); + } + Ok(()) + } + + /// Adds multiple edges (directed or undirected) to the PDAG. + pub fn add_edges_from( + &mut self, + ebunch: Option>, + weights: Option>, + directed: bool, + ) -> Result<(), String> { + let ebunch = ebunch.unwrap_or_default(); + let weights = weights.unwrap_or_else(|| vec![1.0; ebunch.len()]); + + if ebunch.len() != weights.len() { + return Err("The number of elements in ebunch and weights should be equal".to_string()); + } + + for (i, (u, v)) in ebunch.iter().enumerate() { + self.add_edge(u.clone(), v.clone(), Some(weights[i]), directed)?; + } + Ok(()) + } + + /// Returns all neighbors (via directed or undirected edges) of a node. + pub fn all_neighbors(&self, node: &str) -> Result, String> { + let node_idx = self.node_map.get(node) + .ok_or_else(|| format!("Node {} not found", node))?; + + let successors: HashSet = self.graph + .neighbors_directed(*node_idx, Direction::Outgoing) + .map(|idx| self.reverse_node_map[&idx].clone()) + .collect(); + + let predecessors: HashSet = self.graph + .neighbors_directed(*node_idx, Direction::Incoming) + .map(|idx| self.reverse_node_map[&idx].clone()) + .collect(); + + Ok(successors.union(&predecessors).cloned().collect()) + } + + /// Returns children of a node via directed edges (node -> child). + pub fn directed_children(&self, node: &str) -> Result, String> { + let node_idx = self.node_map.get(node) + .ok_or_else(|| format!("Node {} not found", node))?; + + let children: HashSet = self.graph + .neighbors_directed(*node_idx, Direction::Outgoing) + .filter(|&idx| { + let child = &self.reverse_node_map[&idx]; + self.directed_edges.contains(&(node.to_string(), child.to_string())) + }) + .map(|idx| self.reverse_node_map[&idx].clone()) + .collect(); + + Ok(children) + } + + /// Returns parents of a node via directed edges (parent -> node). + pub fn directed_parents(&self, node: &str) -> Result, String> { + let node_idx = self.node_map.get(node) + .ok_or_else(|| format!("Node {} not found", node))?; + + let parents: HashSet = self.graph + .neighbors_directed(*node_idx, Direction::Incoming) + .filter(|&idx| { + let parent = &self.reverse_node_map[&idx]; + self.directed_edges.contains(&(parent.to_string(), node.to_string())) + }) + .map(|idx| self.reverse_node_map[&idx].clone()) + .collect(); + + Ok(parents) + } + + /// Checks if there is a directed edge u -> v. + pub fn has_directed_edge(&self, u: &str, v: &str) -> bool { + self.directed_edges.contains(&(u.to_string(), v.to_string())) + } + + /// Checks if there is an undirected edge u - v. + pub fn has_undirected_edge(&self, u: &str, v: &str) -> bool { + self.undirected_edges.contains(&(u.to_string(), v.to_string())) || + self.undirected_edges.contains(&(v.to_string(), u.to_string())) + } + + /// Returns neighbors connected via undirected edges. + pub fn undirected_neighbors(&self, node: &str) -> Result, String> { + let node_idx = self.node_map.get(node) + .ok_or_else(|| format!("Node {} not found", node))?; + + let neighbors: HashSet = self.graph + .neighbors_directed(*node_idx, Direction::Outgoing) + .filter(|&idx| { + let neighbor = &self.reverse_node_map[&idx]; + self.has_undirected_edge(node, neighbor) + }) + .map(|idx| self.reverse_node_map[&idx].clone()) + .collect(); + + Ok(neighbors) + } + + /// Checks if two nodes are adjacent (via any edge: directed or undirected). + pub fn is_adjacent(&self, u: &str, v: &str) -> bool { + self.has_directed_edge(u, v) || self.has_directed_edge(v, u) || self.has_undirected_edge(u, v) + } + + /// Returns a copy of the PDAG. + pub fn copy(&self) -> RustPDAG { + RustPDAG { + graph: self.graph.clone(), + node_map: self.node_map.clone(), + reverse_node_map: self.reverse_node_map.clone(), + directed_edges: self.directed_edges.clone(), + undirected_edges: self.undirected_edges.clone(), + latents: self.latents.clone(), + } + } + + /// Returns a subgraph containing only directed edges as a RustDAG. + pub fn directed_graph(&self) -> RustDAG { + let mut dag = RustDAG::new(); + + // Add all nodes with their latent status + for node in self.node_map.keys() { + let is_latent = self.latents.contains(node); + dag.add_node(node.clone(), is_latent).unwrap(); + } + + // Add only directed edges + for (u, v) in &self.directed_edges { + dag.add_edge(u.clone(), v.clone(), None).unwrap(); + } + + dag + } + +} diff --git a/rust_core/tests/pdag_tests.rs b/rust_core/tests/pdag_tests.rs new file mode 100644 index 0000000..c772a3d --- /dev/null +++ b/rust_core/tests/pdag_tests.rs @@ -0,0 +1,116 @@ +use std::collections::HashSet; +use rust_core::RustPDAG; + +#[test] +fn test_init_normal() { + // Test initialization with mixed edges + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())]), None, false).unwrap(); + + let expected_edges: HashSet<(String, String)> = vec![ + ("A".to_string(), "C".to_string()), + ("D".to_string(), "C".to_string()), + ("A".to_string(), "B".to_string()), + ("B".to_string(), "A".to_string()), + ("B".to_string(), "D".to_string()), + ("D".to_string(), "B".to_string()), + ].into_iter().collect(); + + // Convert pdag.edges() to HashSet for order-insensitive comparison + let actual_edges: HashSet<(String, String)> = pdag.edges().into_iter().collect(); + assert_eq!(actual_edges, expected_edges); + + let actual_nodes: HashSet = pdag.nodes().into_iter().collect(); + let expected_nodes: HashSet = vec!["A", "B", "C", "D"] + .into_iter() + .map(|s| s.to_string()) + .collect(); + + assert_eq!(actual_nodes, expected_nodes); + + assert_eq!(pdag.directed_edges, HashSet::from_iter(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())])); + assert_eq!(pdag.undirected_edges, HashSet::from_iter(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())])); + + // Test with latents + let mut pdag_latent = RustPDAG::new(); + pdag_latent.add_nodes_from(vec!["A".to_string(), "D".to_string()], Some(vec![true, true])).unwrap(); + pdag_latent.add_edges_from(Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), None, false).unwrap(); + + assert_eq!(pdag_latent.latents, HashSet::from_iter(vec!["A".to_string(), "D".to_string()])); +} + +#[test] +fn test_all_neighbors() { + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())]), None, false).unwrap(); + + assert_eq!(pdag.all_neighbors("A").unwrap(), HashSet::from_iter(vec!["B".to_string(), "C".to_string()])); + assert_eq!(pdag.all_neighbors("B").unwrap(), HashSet::from_iter(vec!["A".to_string(), "D".to_string()])); + assert_eq!(pdag.all_neighbors("C").unwrap(), HashSet::from_iter(vec!["A".to_string(), "D".to_string()])); + assert_eq!(pdag.all_neighbors("D").unwrap(), HashSet::from_iter(vec!["B".to_string(), "C".to_string()])); +} + + +#[test] +fn test_directed_children() { + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())]), None, false).unwrap(); + + assert_eq!(pdag.directed_children("A").unwrap(), HashSet::from_iter(vec!["C".to_string()])); + assert_eq!(pdag.directed_children("B").unwrap(), HashSet::new()); + assert_eq!(pdag.directed_children("C").unwrap(), HashSet::new()); +} + +#[test] +fn test_directed_parents() { + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())]), None, false).unwrap(); + + assert_eq!(pdag.directed_parents("A").unwrap(), HashSet::new()); + assert_eq!(pdag.directed_parents("B").unwrap(), HashSet::new()); + + + assert_eq!(pdag.directed_parents("C").unwrap(), HashSet::from_iter(vec!["A".to_string(), "D".to_string()])); +} + +#[test] +fn test_has_directed_edge() { + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())]), None, false).unwrap(); + + assert!(pdag.has_directed_edge("A", "C")); + assert!(pdag.has_directed_edge("D", "C")); + assert!(!pdag.has_directed_edge("A", "B")); + assert!(!pdag.has_directed_edge("B", "A")); +} + + +#[test] +fn test_has_undirected_edge() { + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())]), None, false).unwrap(); + + assert!(!pdag.has_undirected_edge("A", "C")); + assert!(!pdag.has_undirected_edge("D", "C")); + assert!(pdag.has_undirected_edge("A", "B")); + assert!(pdag.has_undirected_edge("B", "A")); +} + + +#[test] +fn test_undirected_neighbors() { + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())]), None, false).unwrap(); + + assert_eq!(pdag.undirected_neighbors("A").unwrap(), HashSet::from_iter(vec!["B".to_string()])); + assert_eq!(pdag.undirected_neighbors("B").unwrap(), HashSet::from_iter(vec!["A".to_string(), "D".to_string()])); + assert_eq!(pdag.undirected_neighbors("C").unwrap(), HashSet::new()); + assert_eq!(pdag.undirected_neighbors("D").unwrap(), HashSet::from_iter(vec!["B".to_string()])); +} \ No newline at end of file From e88611e7fb5155d83b25e4d14545fd0dc3aaa21d Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Tue, 12 Aug 2025 00:30:44 +0530 Subject: [PATCH 02/13] * orient edges * meeks rule * pdag to dag --- rust_core/src/pdag.rs | 358 +++++++++++++++++++++++++++++++++- rust_core/tests/pdag_tests.rs | 220 +++++++++++++++++++++ 2 files changed, 572 insertions(+), 6 deletions(-) diff --git a/rust_core/src/pdag.rs b/rust_core/src/pdag.rs index 10fcdd4..b4c9179 100644 --- a/rust_core/src/pdag.rs +++ b/rust_core/src/pdag.rs @@ -1,6 +1,7 @@ use petgraph::Direction; use rustworkx_core::petgraph::graph::{DiGraph, NodeIndex}; use std::collections::{HashMap, HashSet}; +use petgraph::visit::Dfs; use crate::RustDAG; @@ -128,11 +129,13 @@ impl RustPDAG { let successors: HashSet = self.graph .neighbors_directed(*node_idx, Direction::Outgoing) + .filter(|&idx| self.reverse_node_map.contains_key(&idx)) .map(|idx| self.reverse_node_map[&idx].clone()) .collect(); let predecessors: HashSet = self.graph .neighbors_directed(*node_idx, Direction::Incoming) + .filter(|&idx| self.reverse_node_map.contains_key(&idx)) .map(|idx| self.reverse_node_map[&idx].clone()) .collect(); @@ -147,8 +150,11 @@ impl RustPDAG { let children: HashSet = self.graph .neighbors_directed(*node_idx, Direction::Outgoing) .filter(|&idx| { - let child = &self.reverse_node_map[&idx]; - self.directed_edges.contains(&(node.to_string(), child.to_string())) + if let Some(child) = self.reverse_node_map.get(&idx) { + self.directed_edges.contains(&(node.to_string(), child.to_string())) + } else { + false // Skip invalid indices + } }) .map(|idx| self.reverse_node_map[&idx].clone()) .collect(); @@ -164,8 +170,11 @@ impl RustPDAG { let parents: HashSet = self.graph .neighbors_directed(*node_idx, Direction::Incoming) .filter(|&idx| { - let parent = &self.reverse_node_map[&idx]; - self.directed_edges.contains(&(parent.to_string(), node.to_string())) + if let Some(parent) = self.reverse_node_map.get(&idx) { + self.directed_edges.contains(&(parent.to_string(), node.to_string())) + } else { + false // Skip invalid indices + } }) .map(|idx| self.reverse_node_map[&idx].clone()) .collect(); @@ -192,8 +201,11 @@ impl RustPDAG { let neighbors: HashSet = self.graph .neighbors_directed(*node_idx, Direction::Outgoing) .filter(|&idx| { - let neighbor = &self.reverse_node_map[&idx]; - self.has_undirected_edge(node, neighbor) + if let Some(neighbor) = self.reverse_node_map.get(&idx) { + self.has_undirected_edge(node, neighbor) + } else { + false // Skip invalid indices + } }) .map(|idx| self.reverse_node_map[&idx].clone()) .collect(); @@ -236,4 +248,338 @@ impl RustPDAG { dag } + /// Orient an undirected edge u - v as u -> v + pub fn orient_undirected_edge(&mut self, u: &str, v: &str, inplace: bool) -> Result, String> { + let mut pdag = if inplace { + self + } else { + &mut self.copy() + }; + + // Check if undirected edge exists + let edge_exists = if pdag.undirected_edges.contains(&(u.to_string(), v.to_string())) { + pdag.undirected_edges.remove(&(u.to_string(), v.to_string())); + true + } else if pdag.undirected_edges.contains(&(v.to_string(), u.to_string())) { + pdag.undirected_edges.remove(&(v.to_string(), u.to_string())); + true + } else { + false + }; + + if !edge_exists { + return Err(format!("Undirected Edge {} - {} not present in the PDAG", u, v)); + } + + // Remove the reverse edge from the graph + let u_idx = pdag.node_map[u]; + let v_idx = pdag.node_map[v]; + + // Find and remove the edge v -> u + if let Some(edge_idx) = pdag.graph.find_edge(v_idx, u_idx) { + pdag.graph.remove_edge(edge_idx); + } + + // Add to directed edges + pdag.directed_edges.insert((u.to_string(), v.to_string())); + + if inplace { + Ok(None) + } else { + Ok(Some(pdag.clone())) + } + } + + /// Check if orienting u -> v would create a new unshielded collider + fn check_new_unshielded_collider(&self, u: &str, v: &str) -> Result { + let parents = self.directed_parents(v)?; + + for parent in parents { + if parent != u && !self.is_adjacent(u, &parent) { + return Ok(true); + } + } + Ok(false) + } + + /// Check if there's a path from source to target in the directed subgraph + pub fn has_directed_path(&self, source: &str, target: &str) -> Result { + let source_idx = self.node_map.get(source) + .ok_or_else(|| format!("Node {} not found", source))?; + let target_idx = self.node_map.get(target) + .ok_or_else(|| format!("Node {} not found", target))?; + + let directed_graph = self.directed_graph(); + let mut dfs = Dfs::new(&directed_graph.graph, *source_idx); + + while let Some(nx) = dfs.next(&directed_graph.graph) { + if nx == *target_idx { + return Ok(true); + } + } + Ok(false) + } + + + /// Apply Meek's rules to orient undirected edges + pub fn apply_meeks_rules(&mut self, apply_r4: bool, inplace: bool) -> Result, String> { + let mut pdag = if inplace { + self + } else { + &mut self.copy() + }; + + let mut changed = true; + while changed { + changed = false; + let nodes: Vec = pdag.nodes(); + + // Rule 1: If X -> Y - Z and + // (X not adj Z) and + // (adding Y -> Z doesn't create cycle) and + // (adding Y -> Z doesn't create an unshielded collider) => Y → Z + for y in &nodes { + if !pdag.node_map.contains_key(y) { + continue; + } + let directed_parents = pdag.directed_parents(y)?; + let undirected_neighbors = pdag.undirected_neighbors(y)?; + + for x in &directed_parents { + for z in &undirected_neighbors { + if !pdag.is_adjacent(x, z) + && !pdag.check_new_unshielded_collider(y, z)? + && !pdag.has_directed_path(z, y)? + { + // Ensure x -> y exists + if pdag.has_directed_edge(x, y) { + if pdag.orient_undirected_edge(y, z, true).is_ok() { + changed = true; + break; + } + } + } + } + if changed { break; } + } + if changed { break; } + } + + // Rule 2: If X -> Z and Z -> Y and X - Y => X -> Y + for z in &nodes { + if !pdag.node_map.contains_key(z) { + continue; + } + let parents = pdag.directed_parents(z)?; + let children = pdag.directed_children(z)?; + + for x in &parents { + for y in &children { + if pdag.has_undirected_edge(x, y) { + // Ensure x -> z and z -> y exist + if pdag.has_directed_edge(x, z) && pdag.has_directed_edge(z, y) { + if pdag.orient_undirected_edge(x, y, true).is_ok() { + changed = true; + break; + } + } + } + } + if changed { break; } + } + if changed { break; } + } + + // Rule 3 + for x in &nodes { + if !pdag.node_map.contains_key(x) { + continue; + } + let undirected_nbs: Vec = pdag.undirected_neighbors(x)?.into_iter().collect(); + + if undirected_nbs.len() < 3 { + continue; + } + + for i in 0..undirected_nbs.len() { + for j in (i + 1)..undirected_nbs.len() { + for k in (j + 1)..undirected_nbs.len() { + let (y, z, w) = (&undirected_nbs[i], &undirected_nbs[j], &undirected_nbs[k]); + + if pdag.has_directed_edge(y, w) && pdag.has_directed_edge(z, w) { + if pdag.orient_undirected_edge(x, w, true).is_ok() { + changed = true; + break; + } + } + } + if changed { break; } + } + if changed { break; } + } + if changed { break; } + } + + // Rule 4 + if apply_r4 { + for c in &nodes { + if !pdag.node_map.contains_key(c) { + continue; + } + let children = pdag.directed_children(c)?; + let parents = pdag.directed_parents(c)?; + + for b in &children { + for d in &parents { + if b == d || pdag.is_adjacent(b, d) { + continue; + } + + let b_undirected = pdag.undirected_neighbors(b)?; + let c_neighbors = pdag.all_neighbors(c)?; + let d_undirected = pdag.undirected_neighbors(d)?; + + for a in &b_undirected { + if c_neighbors.contains(a) && d_undirected.contains(a) { + if pdag.orient_undirected_edge(a, b, true).is_ok() { + changed = true; + break; + } + } + } + if changed { break; } + } + if changed { break; } + } + if changed { break; } + } + } + } + + if inplace { + Ok(None) + } else { + Ok(Some(pdag.clone())) + } + } + + + pub fn to_dag(&self) -> Result { + let mut dag = RustDAG::new(); + + // Add all nodes with latent status + for node in self.nodes() { + let is_latent = self.latents.contains(&node); + dag.add_node(node.clone(), is_latent)?; + } + + // Add all directed edges + for (u, v) in &self.directed_edges { + dag.add_edge(u.clone(), v.clone(), None)?; + } + + let mut pdag_copy = self.copy(); + + // Add undirected edges to dag before node removal + for (u, v) in &self.undirected_edges { + if !dag.has_edge(u, v) && !dag.has_edge(v, u) { + // Try adding u -> v, if it creates cycle, add v -> u + if dag.add_edge(u.clone(), v.clone(), None).is_err() { + dag.add_edge(v.clone(), u.clone(), None)?; + } + } + } + + while !pdag_copy.nodes().is_empty() { + let nodes: Vec = pdag_copy.nodes(); // Get fresh node list + let mut found = false; + + for x in &nodes { + // Check if node still exists + if !pdag_copy.node_map.contains_key(x) { + continue; + } + + // Find nodes with no directed outgoing edges + let directed_children = pdag_copy.directed_children(x)?; + let undirected_neighbors = pdag_copy.undirected_neighbors(x)?; + let directed_parents = pdag_copy.directed_parents(x)?; + + // Check if undirected neighbors + parents form a clique + let mut neighbors_are_clique = true; + for y in &undirected_neighbors { + for z in &directed_parents { + if y != z && !pdag_copy.is_adjacent(y, z) { + neighbors_are_clique = false; + break; + } + } + if !neighbors_are_clique { break; } + } + + if directed_children.is_empty() && (undirected_neighbors.is_empty() || neighbors_are_clique) { + found = true; + + // Add all incoming edges to DAG + let all_predecessors = pdag_copy.all_neighbors(x)?; + for y in &all_predecessors { + if pdag_copy.is_adjacent(y, x) && !dag.has_edge(x, y) { + dag.add_edge(y.clone(), x.clone(), None)?; + } + } + + // Remove node from pdag_copy + pdag_copy.remove_node(x)?; + break; // Break to refresh node list + } + } + + if !found { + // Handle remaining edges arbitrarily, ensuring no cycles + let remaining_edges: Vec<(String, String)> = pdag_copy.undirected_edges.iter().cloned().collect(); + for (u, v) in remaining_edges { + if pdag_copy.node_map.contains_key(&u) && pdag_copy.node_map.contains_key(&v) && !dag.has_edge(&v, &u) { + if let Ok(()) = dag.add_edge(u.clone(), v.clone(), None) { + pdag_copy.orient_undirected_edge(&u, &v, true)?; + } else { + // Try reverse direction if adding u -> v creates a cycle + if !dag.has_edge(&u, &v) { + if let Ok(()) = dag.add_edge(v.clone(), u.clone(), None) { + pdag_copy.orient_undirected_edge(&v, &u, true)?; + } + } + } + } + } + break; + } + } + + Ok(dag) + } + + /// Remove a node from the PDAG + fn remove_node(&mut self, node: &str) -> Result<(), String> { + let node_idx = self.node_map.get(node) + .ok_or_else(|| format!("Node {} not found", node))?; + + // Remove from edge sets + self.directed_edges.retain(|(u, v)| u != node && v != node); + self.undirected_edges.retain(|(u, v)| u != node && v != node); + + // Remove from latents + self.latents.remove(node); + + // Remove from graph + self.graph.remove_node(*node_idx); + + // Remove from mappings + self.reverse_node_map.remove(node_idx); + self.node_map.remove(node); + + Ok(()) + } + + + } diff --git a/rust_core/tests/pdag_tests.rs b/rust_core/tests/pdag_tests.rs index c772a3d..1884844 100644 --- a/rust_core/tests/pdag_tests.rs +++ b/rust_core/tests/pdag_tests.rs @@ -113,4 +113,224 @@ fn test_undirected_neighbors() { assert_eq!(pdag.undirected_neighbors("B").unwrap(), HashSet::from_iter(vec!["A".to_string(), "D".to_string()])); assert_eq!(pdag.undirected_neighbors("C").unwrap(), HashSet::new()); assert_eq!(pdag.undirected_neighbors("D").unwrap(), HashSet::from_iter(vec!["B".to_string()])); +} + +#[test] +fn test_is_adjacent() { + let mut pdag = RustPDAG::new(); + pdag.add_edges_from( + Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), + None, + true + ).unwrap(); + pdag.add_edges_from( + Some(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())]), + None, + false + ).unwrap(); + + assert!(pdag.is_adjacent("A", "B")); + assert!(pdag.is_adjacent("B", "A")); + assert!(pdag.is_adjacent("A", "C")); + assert!(pdag.is_adjacent("C", "A")); + assert!(pdag.is_adjacent("D", "C")); + assert!(pdag.is_adjacent("C", "D")); + assert!(!pdag.is_adjacent("A", "D")); + assert!(!pdag.is_adjacent("B", "C")); +} + + +#[test] +fn test_orient_undirected_edge() { + let mut pdag = RustPDAG::new(); + pdag.add_edges_from( + Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), + None, + true + ).unwrap(); + pdag.add_edges_from( + Some(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())]), + None, + false + ).unwrap(); + + let mod_pdag = pdag.orient_undirected_edge("B", "A", false).unwrap().unwrap(); + let expected_edges: HashSet<(String, String)> = vec![ + ("A".to_string(), "C".to_string()), + ("D".to_string(), "C".to_string()), + ("B".to_string(), "A".to_string()), + ("B".to_string(), "D".to_string()), + ("D".to_string(), "B".to_string()), + ].into_iter().collect(); + + let actual_edges: HashSet<(String, String)> = mod_pdag.edges().into_iter().collect(); + assert_eq!(actual_edges, expected_edges); + assert_eq!(mod_pdag.undirected_edges, HashSet::from_iter(vec![("B".to_string(), "D".to_string())])); + assert_eq!(mod_pdag.directed_edges, HashSet::from_iter(vec![ + ("A".to_string(), "C".to_string()), + ("D".to_string(), "C".to_string()), + ("B".to_string(), "A".to_string()) + ])); + // Test inplace modification + pdag.orient_undirected_edge("B", "A", true).unwrap(); + let expected_edges_inplace: HashSet<(String, String)> = vec![ + ("A".to_string(), "C".to_string()), + ("D".to_string(), "C".to_string()), + ("B".to_string(), "A".to_string()), + ("B".to_string(), "D".to_string()), + ("D".to_string(), "B".to_string()), + ].into_iter().collect(); + + let actual_edges_inplace: HashSet<(String, String)> = pdag.edges().into_iter().collect(); + assert_eq!(actual_edges_inplace, expected_edges_inplace); + assert_eq!(pdag.undirected_edges, HashSet::from_iter(vec![("B".to_string(), "D".to_string())])); + assert_eq!(pdag.directed_edges, HashSet::from_iter(vec![ + ("A".to_string(), "C".to_string()), + ("D".to_string(), "C".to_string()), + ("B".to_string(), "A".to_string()) + ])); + // Test error case - edge doesn't exist + assert!(pdag.orient_undirected_edge("A", "C", true).is_err()); +} + + +#[test] +fn test_copy() { + // Test copy with mixed edges + let mut pdag_mix = RustPDAG::new(); + pdag_mix.add_edges_from( + Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), + None, + true + ).unwrap(); + pdag_mix.add_edges_from( + Some(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())]), + None, + false + ).unwrap(); + let pdag_copy = pdag_mix.copy(); + let expected_edges: HashSet<(String, String)> = vec![ + ("A".to_string(), "C".to_string()), + ("D".to_string(), "C".to_string()), + ("A".to_string(), "B".to_string()), + ("B".to_string(), "A".to_string()), + ("B".to_string(), "D".to_string()), + ("D".to_string(), "B".to_string()), + ].into_iter().collect(); + + let actual_edges: HashSet<(String, String)> = pdag_copy.edges().into_iter().collect(); + assert_eq!(actual_edges, expected_edges); + assert_eq!(pdag_copy.nodes().into_iter().collect::>(), + HashSet::from_iter(vec!["A".to_string(), "B".to_string(), "C".to_string(), "D".to_string()])); + assert_eq!(pdag_copy.directed_edges, HashSet::from_iter(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())])); + assert_eq!(pdag_copy.undirected_edges, HashSet::from_iter(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())])); + assert_eq!(pdag_copy.latents, HashSet::new()); + // Test copy with latents + let mut pdag_latent = RustPDAG::new(); + pdag_latent.add_edges_from( + Some(vec![("A".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), + None, + true + ).unwrap(); + pdag_latent.add_edges_from( + Some(vec![("B".to_string(), "A".to_string()), ("B".to_string(), "D".to_string())]), + None, + false + ).unwrap(); + pdag_latent.latents.insert("A".to_string()); + pdag_latent.latents.insert("D".to_string()); + let pdag_copy_latent = pdag_latent.copy(); + assert_eq!(pdag_copy_latent.latents, HashSet::from_iter(vec!["A".to_string(), "D".to_string()])); +} + +#[test] +fn test_apply_meeks_rules_basic() { + // Test Rule 1: A -> B - C and A not adjacent to C => B -> C + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "B".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "C".to_string())]), None, false).unwrap(); + + let cpdag = pdag.apply_meeks_rules(true, false).unwrap().unwrap(); + assert!(cpdag.has_directed_edge("A", "B")); + assert!(cpdag.has_directed_edge("B", "C")); + assert_eq!(cpdag.edges().into_iter().collect::>(), + HashSet::from_iter(vec![("A".to_string(), "B".to_string()), ("B".to_string(), "C".to_string())])); +} + +#[test] +fn test_apply_meeks_rules_rule2() { + // Test Rule 2: A -> B -> C and A - C => A -> C + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "B".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "C".to_string()), ("C".to_string(), "D".to_string())]), None, false).unwrap(); + + let cpdag = pdag.apply_meeks_rules(true, false).unwrap().unwrap(); + + let expected_edges: HashSet<(String, String)> = vec![ + ("A".to_string(), "B".to_string()), + ("B".to_string(), "C".to_string()), + ("C".to_string(), "D".to_string()), + ].into_iter().collect(); + + assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges); +} + +#[test] +fn test_apply_meeks_rules_no_change() { + // Test case where no rules apply + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "B".to_string()), ("D".to_string(), "C".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "C".to_string())]), None, false).unwrap(); + + let cpdag = pdag.apply_meeks_rules(true, false).unwrap().unwrap(); + let expected_edges: HashSet<(String, String)> = vec![ + ("A".to_string(), "B".to_string()), + ("D".to_string(), "C".to_string()), + ("B".to_string(), "C".to_string()), + ("C".to_string(), "B".to_string()), + ].into_iter().collect(); + + assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges); +} + +#[test] +fn test_apply_meeks_rules_inplace() { + // Test inplace modification + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "B".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "C".to_string())]), None, false).unwrap(); + + pdag.apply_meeks_rules(true, true).unwrap(); + assert!(pdag.has_directed_edge("A", "B")); + assert!(pdag.has_directed_edge("B", "C")); + assert_eq!(pdag.edges().into_iter().collect::>(), + HashSet::from_iter(vec![("A".to_string(), "B".to_string()), ("B".to_string(), "C".to_string())])); +} + + + +#[test] +fn test_to_dag_basic() { + let mut pdag = RustPDAG::new(); + pdag.add_edges_from( + Some(vec![("A".to_string(), "C".to_string()), ("C".to_string(), "B".to_string())]), + None, + true + ).unwrap(); + pdag.add_edges_from( + Some(vec![("C".to_string(), "D".to_string()), ("D".to_string(), "A".to_string())]), + None, + false + ).unwrap(); + + let dag = pdag.to_dag().unwrap(); + let dag_edges: HashSet<(String, String)> = dag.edges().into_iter().collect(); + + // Expected edges: A -> C, C -> B, and either C -> D, A -> D or D -> C, D -> A + assert_eq!(dag_edges.len(), 4); + assert!(dag.has_edge("A", "C")); + assert!(dag.has_edge("C", "B")); + assert!(!(dag.has_edge("A", "D") && dag.has_edge("C", "D"))); // No V-structure + assert!(dag_edges.contains(&("C".to_string(), "D".to_string())) || dag_edges.contains(&("D".to_string(), "C".to_string()))); + assert!(dag_edges.contains(&("D".to_string(), "A".to_string())) || dag_edges.contains(&("A".to_string(), "D".to_string()))); } \ No newline at end of file From a66c6f446082b13b3ebc49879848c4c4d30e0d32 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Wed, 13 Aug 2025 00:18:27 +0530 Subject: [PATCH 03/13] deterministic sorting --- rust_core/src/pdag.rs | 158 +++++++++++++++++++++++++----------------- 1 file changed, 93 insertions(+), 65 deletions(-) diff --git a/rust_core/src/pdag.rs b/rust_core/src/pdag.rs index b4c9179..c917fa6 100644 --- a/rust_core/src/pdag.rs +++ b/rust_core/src/pdag.rs @@ -5,7 +5,6 @@ use petgraph::visit::Dfs; use crate::RustDAG; - #[derive(Debug, Clone)] pub struct RustPDAG { pub graph: DiGraph, @@ -15,6 +14,7 @@ pub struct RustPDAG { pub undirected_edges: HashSet<(String, String)>, pub latents: HashSet, } + impl RustPDAG { pub fn new() -> Self { RustPDAG { @@ -27,9 +27,9 @@ impl RustPDAG { } } - /// Get all edges in the graph + /// Get all edges in the graph - DETERMINISTIC pub fn edges(&self) -> Vec<(String, String)> { - self.graph + let mut edges: Vec<(String, String)> = self.graph .edge_indices() .map(|edge_idx| { let (source, target) = self.graph.edge_endpoints(edge_idx).unwrap(); @@ -38,15 +38,18 @@ impl RustPDAG { self.reverse_node_map[&target].clone(), ) }) - .collect() + .collect(); + edges.sort(); + edges } /// Get all nodes in the graph pub fn nodes(&self) -> Vec { let mut nodes: Vec = self.node_map.keys().cloned().collect(); - nodes.sort(); // Sort alphabetically for deterministic order + nodes.sort(); nodes } + /// Adds a single node to the PDAG. pub fn add_node(&mut self, node: String, latent: bool) -> Result<(), String> { if !self.node_map.contains_key(&node) { @@ -230,19 +233,23 @@ impl RustPDAG { } } - /// Returns a subgraph containing only directed edges as a RustDAG. + /// Returns a subgraph containing only directed edges as a RustDAG - DETERMINISTIC pub fn directed_graph(&self) -> RustDAG { let mut dag = RustDAG::new(); - // Add all nodes with their latent status - for node in self.node_map.keys() { - let is_latent = self.latents.contains(node); + // Add all nodes with their latent status - DETERMINISTIC ORDER + let mut nodes: Vec = self.node_map.keys().cloned().collect(); + nodes.sort(); + for node in nodes { + let is_latent = self.latents.contains(&node); dag.add_node(node.clone(), is_latent).unwrap(); } // Add only directed edges - for (u, v) in &self.directed_edges { - dag.add_edge(u.clone(), v.clone(), None).unwrap(); + let mut directed_edges: Vec<(String, String)> = self.directed_edges.iter().cloned().collect(); + directed_edges.sort(); + for (u, v) in directed_edges { + dag.add_edge(u, v, None).unwrap(); } dag @@ -320,40 +327,50 @@ impl RustPDAG { Ok(false) } - /// Apply Meek's rules to orient undirected edges pub fn apply_meeks_rules(&mut self, apply_r4: bool, inplace: bool) -> Result, String> { - let mut pdag = if inplace { - self + if inplace { + // Work directly on self + self.apply_meeks_rules_internal(apply_r4)?; + Ok(None) } else { - &mut self.copy() - }; + // Work on a copy + let mut pdag_copy = self.copy(); + pdag_copy.apply_meeks_rules_internal(apply_r4)?; + Ok(Some(pdag_copy)) + } + } + /// Internal method that applies Meek's rules to the current instance + fn apply_meeks_rules_internal(&mut self, apply_r4: bool) -> Result<(), String> { let mut changed = true; while changed { changed = false; - let nodes: Vec = pdag.nodes(); + let nodes: Vec = self.nodes(); // Rule 1: If X -> Y - Z and // (X not adj Z) and // (adding Y -> Z doesn't create cycle) and // (adding Y -> Z doesn't create an unshielded collider) => Y → Z for y in &nodes { - if !pdag.node_map.contains_key(y) { + if !self.node_map.contains_key(y) { continue; } - let directed_parents = pdag.directed_parents(y)?; - let undirected_neighbors = pdag.undirected_neighbors(y)?; + // Convert HashSets to sorted vectors for deterministic iteration + let mut directed_parents: Vec = self.directed_parents(y)?.into_iter().collect(); + directed_parents.sort(); + let mut undirected_neighbors: Vec = self.undirected_neighbors(y)?.into_iter().collect(); + undirected_neighbors.sort(); for x in &directed_parents { for z in &undirected_neighbors { - if !pdag.is_adjacent(x, z) - && !pdag.check_new_unshielded_collider(y, z)? - && !pdag.has_directed_path(z, y)? + if !self.is_adjacent(x, z) + && !self.check_new_unshielded_collider(y, z)? + && !self.has_directed_path(z, y)? { // Ensure x -> y exists - if pdag.has_directed_edge(x, y) { - if pdag.orient_undirected_edge(y, z, true).is_ok() { + if self.has_directed_edge(x, y) { + if self.orient_undirected_edge(y, z, true).is_ok() { changed = true; break; } @@ -367,18 +384,21 @@ impl RustPDAG { // Rule 2: If X -> Z and Z -> Y and X - Y => X -> Y for z in &nodes { - if !pdag.node_map.contains_key(z) { + if !self.node_map.contains_key(z) { continue; } - let parents = pdag.directed_parents(z)?; - let children = pdag.directed_children(z)?; + // Convert HashSets to sorted vectors for deterministic iteration + let mut parents: Vec = self.directed_parents(z)?.into_iter().collect(); + parents.sort(); + let mut children: Vec = self.directed_children(z)?.into_iter().collect(); + children.sort(); for x in &parents { for y in &children { - if pdag.has_undirected_edge(x, y) { + if self.has_undirected_edge(x, y) { // Ensure x -> z and z -> y exist - if pdag.has_directed_edge(x, z) && pdag.has_directed_edge(z, y) { - if pdag.orient_undirected_edge(x, y, true).is_ok() { + if self.has_directed_edge(x, z) && self.has_directed_edge(z, y) { + if self.orient_undirected_edge(x, y, true).is_ok() { changed = true; break; } @@ -390,12 +410,14 @@ impl RustPDAG { if changed { break; } } - // Rule 3 + // Rule 3: If X - Y, X - Z, X - W and Y -> W, Z -> W => X -> W for x in &nodes { - if !pdag.node_map.contains_key(x) { + if !self.node_map.contains_key(x) { continue; } - let undirected_nbs: Vec = pdag.undirected_neighbors(x)?.into_iter().collect(); + // Convert HashSet to sorted vector for deterministic iteration + let mut undirected_nbs: Vec = self.undirected_neighbors(x)?.into_iter().collect(); + undirected_nbs.sort(); if undirected_nbs.len() < 3 { continue; @@ -406,8 +428,8 @@ impl RustPDAG { for k in (j + 1)..undirected_nbs.len() { let (y, z, w) = (&undirected_nbs[i], &undirected_nbs[j], &undirected_nbs[k]); - if pdag.has_directed_edge(y, w) && pdag.has_directed_edge(z, w) { - if pdag.orient_undirected_edge(x, w, true).is_ok() { + if self.has_directed_edge(y, w) && self.has_directed_edge(z, w) { + if self.orient_undirected_edge(x, w, true).is_ok() { changed = true; break; } @@ -423,25 +445,31 @@ impl RustPDAG { // Rule 4 if apply_r4 { for c in &nodes { - if !pdag.node_map.contains_key(c) { + if !self.node_map.contains_key(c) { continue; } - let children = pdag.directed_children(c)?; - let parents = pdag.directed_parents(c)?; + + let mut children: Vec = self.directed_children(c)?.into_iter().collect(); + children.sort(); + let mut parents: Vec = self.directed_parents(c)?.into_iter().collect(); + parents.sort(); for b in &children { for d in &parents { - if b == d || pdag.is_adjacent(b, d) { + if b == d || self.is_adjacent(b, d) { continue; } - let b_undirected = pdag.undirected_neighbors(b)?; - let c_neighbors = pdag.all_neighbors(c)?; - let d_undirected = pdag.undirected_neighbors(d)?; + let mut b_undirected: Vec = self.undirected_neighbors(b)?.into_iter().collect(); + b_undirected.sort(); + let mut c_neighbors: Vec = self.all_neighbors(c)?.into_iter().collect(); + c_neighbors.sort(); + let mut d_undirected: Vec = self.undirected_neighbors(d)?.into_iter().collect(); + d_undirected.sort(); for a in &b_undirected { if c_neighbors.contains(a) && d_undirected.contains(a) { - if pdag.orient_undirected_edge(a, b, true).is_ok() { + if self.orient_undirected_edge(a, b, true).is_ok() { changed = true; break; } @@ -456,14 +484,9 @@ impl RustPDAG { } } - if inplace { - Ok(None) - } else { - Ok(Some(pdag.clone())) - } + Ok(()) } - pub fn to_dag(&self) -> Result { let mut dag = RustDAG::new(); @@ -473,25 +496,29 @@ impl RustPDAG { dag.add_node(node.clone(), is_latent)?; } - // Add all directed edges - for (u, v) in &self.directed_edges { - dag.add_edge(u.clone(), v.clone(), None)?; + // Add all directed edg + let mut directed_edges_sorted: Vec<(String, String)> = self.directed_edges.iter().cloned().collect(); + directed_edges_sorted.sort(); + for (u, v) in directed_edges_sorted { + dag.add_edge(u, v, None)?; } let mut pdag_copy = self.copy(); // Add undirected edges to dag before node removal - for (u, v) in &self.undirected_edges { - if !dag.has_edge(u, v) && !dag.has_edge(v, u) { + let mut undirected_edges_sorted: Vec<(String, String)> = self.undirected_edges.iter().cloned().collect(); + undirected_edges_sorted.sort(); + for (u, v) in undirected_edges_sorted { + if !dag.has_edge(&u, &v) && !dag.has_edge(&v, &u) { // Try adding u -> v, if it creates cycle, add v -> u if dag.add_edge(u.clone(), v.clone(), None).is_err() { - dag.add_edge(v.clone(), u.clone(), None)?; + dag.add_edge(v, u, None)?; } } } while !pdag_copy.nodes().is_empty() { - let nodes: Vec = pdag_copy.nodes(); // Get fresh node list + let nodes: Vec = pdag_copy.nodes(); let mut found = false; for x in &nodes { @@ -502,8 +529,10 @@ impl RustPDAG { // Find nodes with no directed outgoing edges let directed_children = pdag_copy.directed_children(x)?; - let undirected_neighbors = pdag_copy.undirected_neighbors(x)?; - let directed_parents = pdag_copy.directed_parents(x)?; + let mut undirected_neighbors: Vec = pdag_copy.undirected_neighbors(x)?.into_iter().collect(); + undirected_neighbors.sort(); + let mut directed_parents: Vec = pdag_copy.directed_parents(x)?.into_iter().collect(); + directed_parents.sort(); // Check if undirected neighbors + parents form a clique let mut neighbors_are_clique = true; @@ -521,7 +550,8 @@ impl RustPDAG { found = true; // Add all incoming edges to DAG - let all_predecessors = pdag_copy.all_neighbors(x)?; + let mut all_predecessors: Vec = pdag_copy.all_neighbors(x)?.into_iter().collect(); + all_predecessors.sort(); for y in &all_predecessors { if pdag_copy.is_adjacent(y, x) && !dag.has_edge(x, y) { dag.add_edge(y.clone(), x.clone(), None)?; @@ -536,7 +566,8 @@ impl RustPDAG { if !found { // Handle remaining edges arbitrarily, ensuring no cycles - let remaining_edges: Vec<(String, String)> = pdag_copy.undirected_edges.iter().cloned().collect(); + let mut remaining_edges: Vec<(String, String)> = pdag_copy.undirected_edges.iter().cloned().collect(); + remaining_edges.sort(); // Deterministic order for (u, v) in remaining_edges { if pdag_copy.node_map.contains_key(&u) && pdag_copy.node_map.contains_key(&v) && !dag.has_edge(&v, &u) { if let Ok(()) = dag.add_edge(u.clone(), v.clone(), None) { @@ -579,7 +610,4 @@ impl RustPDAG { Ok(()) } - - - -} +} \ No newline at end of file From fadeaf6a3b168abfcac0b2f7f096e003413d1f26 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 15 Aug 2025 12:55:48 +0530 Subject: [PATCH 04/13] initial commit --- python_bindings/src/lib.rs | 168 ++++++++++++++- python_bindings/tests/test_pdag.py | 323 +++++++++++++++++++++++++++++ 2 files changed, 487 insertions(+), 4 deletions(-) create mode 100644 python_bindings/tests/test_pdag.py diff --git a/python_bindings/src/lib.rs b/python_bindings/src/lib.rs index ce5d2ea..0e352f0 100644 --- a/python_bindings/src/lib.rs +++ b/python_bindings/src/lib.rs @@ -1,7 +1,165 @@ use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; -use rust_core::{IndependenceAssertion, Independencies, RustDAG}; -use std::collections::HashSet; +use rust_core::{IndependenceAssertion, Independencies, RustDAG, RustPDAG}; +use std::collections::{HashMap, HashSet}; + +#[pyclass(name = "PDAG")] +#[derive(Clone)] +pub struct PyRustPDAG { + inner: RustPDAG, +} + +#[pymethods] +impl PyRustPDAG { + #[new] + pub fn new() -> Self { + PyRustPDAG { + inner: RustPDAG::new(), + } + } + + /// Add a single node to the PDAG. + pub fn add_node(&mut self, node: String, latent: Option) -> PyResult<()> { + self.inner + .add_node(node, latent.unwrap_or(false)) + .map_err(PyValueError::new_err) + } + + /// Add multiple nodes to the PDAG. + pub fn add_nodes_from(&mut self, nodes: Vec, latent: Option>) -> PyResult<()> { + self.inner + .add_nodes_from(nodes, latent) + .map_err(PyValueError::new_err) + } + + /// Add a single edge (directed or undirected) to the PDAG. + #[pyo3(signature = (u, v, weight = None, directed = true))] + pub fn add_edge(&mut self, u: String, v: String, weight: Option, directed: bool) -> PyResult<()> { + self.inner + .add_edge(u, v, weight, directed) + .map_err(PyValueError::new_err) + } + + /// Add multiple edges (directed or undirected) to the PDAG. + #[pyo3(signature = (ebunch, weights = None, directed = true))] + pub fn add_edges_from( + &mut self, + ebunch: Vec<(String, String)>, + weights: Option>, + directed: bool, + ) -> PyResult<()> { + self.inner + .add_edges_from(Some(ebunch), weights, directed) + .map_err(PyValueError::new_err) + } + + /// Get all neighbors (via directed or undirected edges) of a node. + pub fn all_neighbors(&self, node: String) -> PyResult> { + let neighbors = self.inner + .all_neighbors(&node) + .map_err(PyKeyError::new_err)?; + let mut result: Vec = neighbors.into_iter().collect(); + result.sort(); // Ensure deterministic order + Ok(result) + } + + /// Get children of a node via directed edges. + pub fn directed_children(&self, node: String) -> PyResult> { + let children = self.inner + .directed_children(&node) + .map_err(PyKeyError::new_err)?; + let mut result: Vec = children.into_iter().collect(); + result.sort(); + Ok(result) + } + + /// Get parents of a node via directed edges. + pub fn directed_parents(&self, node: String) -> PyResult> { + let parents = self.inner + .directed_parents(&node) + .map_err(PyKeyError::new_err)?; + let mut result: Vec = parents.into_iter().collect(); + result.sort(); + Ok(result) + } + + /// Check if there is a directed edge u -> v. + pub fn has_directed_edge(&self, u: String, v: String) -> bool { + self.inner.has_directed_edge(&u, &v) + } + + /// Check if there is an undirected edge u - v. + pub fn has_undirected_edge(&self, u: String, v: String) -> bool { + self.inner.has_undirected_edge(&u, &v) + } + + /// Get neighbors connected via undirected edges. + pub fn undirected_neighbors(&self, node: String) -> PyResult> { + let neighbors = self.inner + .undirected_neighbors(&node) + .map_err(PyKeyError::new_err)?; + let mut result: Vec = neighbors.into_iter().collect(); + result.sort(); + Ok(result) + } + + /// Check if two nodes are adjacent (via any edge). + pub fn is_adjacent(&self, u: String, v: String) -> bool { + self.inner.is_adjacent(&u, &v) + } + + /// Get all nodes in the PDAG. + pub fn nodes(&self) -> Vec { + self.inner.nodes() + } + + /// Get all edges in the PDAG. + pub fn edges(&self) -> Vec<(String, String)> { + self.inner.edges() + } + + /// Apply Meek's rules to orient undirected edges. + #[pyo3(signature = (apply_r4 = true, inplace = true))] + pub fn apply_meeks_rules(&mut self, apply_r4: bool, inplace: bool) -> PyResult> { + self.inner + .apply_meeks_rules(apply_r4, inplace) + .map(|opt| opt.map(|pdag| PyRustPDAG { inner: pdag })) + .map_err(PyValueError::new_err) + } + + /// Convert the PDAG to a DAG. + pub fn to_dag(&self) -> PyResult { + self.inner + .to_dag() + .map(|dag| PyRustDAG { inner: dag }) + .map_err(PyValueError::new_err) + } + + /// Check if there is a directed path from source to target. + pub fn has_directed_path(&self, source: String, target: String) -> PyResult { + self.inner + .has_directed_path(&source, &target) + .map_err(PyKeyError::new_err) + } + + /// Get the number of nodes in the PDAG. + pub fn node_count(&self) -> usize { + self.inner.node_map.len() + } + + /// Get the number of edges in the PDAG (counting undirected edges once). + pub fn edge_count(&self) -> usize { + self.inner.directed_edges.len() + self.inner.undirected_edges.len() + } + + /// Get the set of latent nodes. + #[getter] + pub fn latents(&self) -> Vec { + let mut result: Vec = self.inner.latents.iter().cloned().collect(); + result.sort(); + result + } +} #[pyclass(name = "DAG")] #[derive(Clone)] @@ -131,6 +289,7 @@ impl PyRustDAG { } } +// Existing PyIndependenceAssertion and PyIndependencies (unchanged, included for completeness) #[pyclass(name = "IndependenceAssertion")] #[derive(Clone)] pub struct PyIndependenceAssertion { @@ -155,7 +314,7 @@ impl PyIndependenceAssertion { #[getter] pub fn event1(&self) -> Vec { let mut result: Vec = self.inner.event1.iter().cloned().collect(); - result.sort(); // Ensure deterministic order + result.sort(); result } @@ -292,7 +451,8 @@ impl PyIndependencies { #[pymodule] fn causalgraphs(_py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; Ok(()) -} +} \ No newline at end of file diff --git a/python_bindings/tests/test_pdag.py b/python_bindings/tests/test_pdag.py new file mode 100644 index 0000000..0876410 --- /dev/null +++ b/python_bindings/tests/test_pdag.py @@ -0,0 +1,323 @@ +import unittest +from causalgraphs import PDAG, DAG + +class TestPDAG(unittest.TestCase): + def setUp(self): + # PDAG with mixed directed and undirected edges + self.pdag_mix = PDAG() + self.pdag_mix.add_edges_from([("A", "C"), ("D", "C")], directed=True) + self.pdag_mix.add_edges_from([("B", "A"), ("B", "D")], directed=False) + + # PDAG with only directed edges + self.pdag_dir = PDAG() + self.pdag_dir.add_edges_from([("A", "B"), ("D", "B"), ("A", "C"), ("D", "C")], directed=True) + + # PDAG with only undirected edges + self.pdag_undir = PDAG() + self.pdag_undir.add_edges_from([("A", "C"), ("D", "C"), ("B", "A"), ("B", "D")], directed=False) + + # PDAG with latents + self.pdag_latent = PDAG() + self.pdag_latent.add_nodes_from(["A", "B", "C", "D"], latent=[True, False, False, True]) + self.pdag_latent.add_edges_from([("A", "C"), ("D", "C")], directed=True) + self.pdag_latent.add_edges_from([("B", "A"), ("B", "D")], directed=False) + + def test_init_normal(self): + # Mix directed and undirected + pdag = PDAG() + directed_edges = [("A", "C"), ("D", "C")] + undirected_edges = [("B", "A"), ("B", "D")] + pdag.add_edges_from(directed_edges, directed=True) + pdag.add_edges_from(undirected_edges, directed=False) + expected_edges = {("A", "C"), ("D", "C"), ("A", "B"), ("B", "A"), ("B", "D"), ("D", "B")} + self.assertEqual(set(pdag.edges()), expected_edges) + self.assertEqual(set(pdag.nodes()), {"A", "B", "C", "D"}) + self.assertEqual(set(pdag.directed_edges()), set(directed_edges)) + self.assertEqual(set(pdag.undirected_edges()), set(undirected_edges)) + + # Mix with latents + pdag = PDAG() + pdag.add_nodes_from(["A", "B", "C", "D"], latent=[True, False, True, False]) + pdag.add_edges_from(directed_edges, directed=True) + pdag.add_edges_from(undirected_edges, directed=False) + self.assertEqual(set(pdag.edges()), expected_edges) + self.assertEqual(set(pdag.nodes()), {"A", "B", "C", "D"}) + self.assertEqual(set(pdag.directed_edges()), set(directed_edges)) + self.assertEqual(set(pdag.undirected_edges()), set(undirected_edges)) + self.assertEqual(set(pdag.latents), {"A", "C"}) + + # Only undirected + pdag = PDAG() + undirected_edges = [("A", "C"), ("D", "C"), ("B", "A"), ("B", "D")] + pdag.add_edges_from(undirected_edges, directed=False) + expected_edges = {("A", "C"), ("C", "A"), ("D", "C"), ("C", "D"), ("B", "A"), ("A", "B"), ("B", "D"), ("D", "B")} + self.assertEqual(set(pdag.edges()), expected_edges) + self.assertEqual(set(pdag.nodes()), {"A", "B", "C", "D"}) + self.assertEqual(set(pdag.directed_edges()), set()) + self.assertEqual(set(pdag.undirected_edges()), set(undirected_edges)) + + # Only undirected with latents + pdag = PDAG() + pdag.add_nodes_from(["A", "B", "C", "D"], latent=[True, False, False, True]) + pdag.add_edges_from(undirected_edges, directed=False) + self.assertEqual(set(pdag.edges()), expected_edges) + self.assertEqual(set(pdag.nodes()), {"A", "B", "C", "D"}) + self.assertEqual(set(pdag.directed_edges()), set()) + self.assertEqual(set(pdag.undirected_edges()), set(undirected_edges)) + self.assertEqual(set(pdag.latents), {"A", "D"}) + + # Only directed + pdag = PDAG() + directed_edges = [("A", "B"), ("D", "B"), ("A", "C"), ("D", "C")] + pdag.add_edges_from(directed_edges, directed=True) + self.assertEqual(set(pdag.edges()), set(directed_edges)) + self.assertEqual(set(pdag.nodes()), {"A", "B", "C", "D"}) + self.assertEqual(set(pdag.directed_edges()), set(directed_edges)) + self.assertEqual(set(pdag.undirected_edges()), set()) + + # Only directed with latents + pdag = PDAG() + pdag.add_nodes_from(["A", "B", "C", "D"], latent=[False, False, False, True]) + pdag.add_edges_from(directed_edges, directed=True) + self.assertEqual(set(pdag.edges()), set(directed_edges)) + self.assertEqual(set(pdag.nodes()), {"A", "B", "C", "D"}) + self.assertEqual(set(pdag.directed_edges()), set(directed_edges)) + self.assertEqual(set(pdag.undirected_edges()), set()) + self.assertEqual(set(pdag.latents), {"D"}) + + def test_all_neighbors(self): + pdag = self.pdag_mix + self.assertEqual(set(pdag.all_neighbors("A")), {"B", "C"}) + self.assertEqual(set(pdag.all_neighbors("B")), {"A", "D"}) + self.assertEqual(set(pdag.all_neighbors("C")), {"A", "D"}) + self.assertEqual(set(pdag.all_neighbors("D")), {"B", "C"}) + + def test_directed_children(self): + pdag = self.pdag_mix + self.assertEqual(set(pdag.directed_children("A")), {"C"}) + self.assertEqual(set(pdag.directed_children("B")), set()) + self.assertEqual(set(pdag.directed_children("C")), set()) + self.assertEqual(set(pdag.directed_children("D")), {"C"}) + + def test_directed_parents(self): + pdag = self.pdag_mix + self.assertEqual(set(pdag.directed_parents("A")), set()) + self.assertEqual(set(pdag.directed_parents("B")), set()) + self.assertEqual(set(pdag.directed_parents("C")), {"A", "D"}) + self.assertEqual(set(pdag.directed_parents("D")), set()) + + def test_has_directed_edge(self): + pdag = self.pdag_mix + self.assertTrue(pdag.has_directed_edge("A", "C")) + self.assertTrue(pdag.has_directed_edge("D", "C")) + self.assertFalse(pdag.has_directed_edge("A", "B")) + self.assertFalse(pdag.has_directed_edge("B", "A")) + + def test_has_undirected_edge(self): + pdag = self.pdag_mix + self.assertFalse(pdag.has_undirected_edge("A", "C")) + self.assertFalse(pdag.has_undirected_edge("D", "C")) + self.assertTrue(pdag.has_undirected_edge("A", "B")) + self.assertTrue(pdag.has_undirected_edge("B", "A")) + self.assertTrue(pdag.has_undirected_edge("B", "D")) + + def test_undirected_neighbors(self): + pdag = self.pdag_mix + self.assertEqual(set(pdag.undirected_neighbors("A")), {"B"}) + self.assertEqual(set(pdag.undirected_neighbors("B")), {"A", "D"}) + self.assertEqual(set(pdag.undirected_neighbors("C")), set()) + self.assertEqual(set(pdag.undirected_neighbors("D")), {"B"}) + + def test_orient_undirected_edge(self): + pdag = self.pdag_mix.copy() + mod_pdag = pdag.orient_undirected_edge("B", "A", inplace=False) + self.assertEqual( + set(mod_pdag.edges()), + {("A", "C"), ("D", "C"), ("B", "A"), ("B", "D"), ("D", "B")} + ) + self.assertEqual(set(mod_pdag.undirected_edges()), {("B", "D")}) + self.assertEqual(set(mod_pdag.directed_edges()), {("A", "C"), ("D", "C"), ("B", "A")}) + + pdag.orient_undirected_edge("B", "A", inplace=True) + self.assertEqual( + set(pdag.edges()), + {("A", "C"), ("D", "C"), ("B", "A"), ("B", "D"), ("D", "B")} + ) + self.assertEqual(set(pdag.undirected_edges()), {("B", "D")}) + self.assertEqual(set(pdag.directed_edges()), {("A", "C"), ("D", "C"), ("B", "A")}) + + with self.assertRaises(ValueError): + pdag.orient_undirected_edge("B", "A", inplace=True) + + def test_copy(self): + pdag_copy = self.pdag_mix.copy() + expected_edges = {("A", "C"), ("D", "C"), ("A", "B"), ("B", "A"), ("B", "D"), ("D", "B")} + expected_dir = [("A", "C"), ("D", "C")] + expected_undir = [("B", "A"), ("B", "D")] + self.assertEqual(set(pdag_copy.edges()), expected_edges) + self.assertEqual(set(pdag_copy.nodes()), {"A", "B", "C", "D"}) + self.assertEqual(set(pdag_copy.directed_edges()), set(expected_dir)) + self.assertEqual(set(pdag_copy.undirected_edges()), set(expected_undir)) + self.assertEqual(set(pdag_copy.latents), set()) + + pdag_copy = self.pdag_latent.copy() + self.assertEqual(set(pdag_copy.edges()), expected_edges) + self.assertEqual(set(pdag_copy.nodes()), {"A", "B", "C", "D"}) + self.assertEqual(set(pdag_copy.directed_edges()), set(expected_dir)) + self.assertEqual(set(pdag_copy.undirected_edges()), set(expected_undir)) + self.assertEqual(set(pdag_copy.latents), {"A", "D"}) + + def test_pdag_to_dag(self): + # PDAG no: 1 - Possibility of creating a v-structure + pdag = PDAG() + pdag.add_edges_from([("A", "B"), ("C", "B")], directed=True) + pdag.add_edges_from([("C", "D"), ("D", "A")], directed=False) + dag = pdag.to_dag() + self.assertTrue(("A", "B") in dag.edges()) + self.assertTrue(("C", "B") in dag.edges()) + self.assertFalse(("A", "D") in dag.edges() and ("C", "D") in dag.edges()) + self.assertEqual(len(dag.edges()), 4) + + # With latents + pdag = PDAG() + pdag.add_nodes_from(["A", "B", "C", "D"], latent=[True, False, False, False]) + pdag.add_edges_from([("A", "B"), ("C", "B")], directed=True) + pdag.add_edges_from([("C", "D"), ("D", "A")], directed=False) + dag = pdag.to_dag() + self.assertTrue(("A", "B") in dag.edges()) + self.assertTrue(("C", "B") in dag.edges()) + self.assertFalse(("A", "D") in dag.edges() and ("C", "D") in dag.edges()) + self.assertEqual(set(dag.latents), {"A"}) + self.assertEqual(len(dag.edges()), 4) + + # PDAG no: 2 - No possibility of creating a v-structure + pdag = PDAG() + pdag.add_edges_from([("B", "C"), ("A", "C")], directed=True) + pdag.add_edges_from([("A", "D")], directed=False) + dag = pdag.to_dag() + self.assertTrue(("B", "C") in dag.edges()) + self.assertTrue(("A", "C") in dag.edges()) + self.assertTrue(("A", "D") in dag.edges() or ("D", "A") in dag.edges()) + + # With latents + pdag = PDAG() + pdag.add_nodes_from(["A", "B", "C", "D"], latent=[True, False, False, False]) + pdag.add_edges_from([("B", "C"), ("A", "C")], directed=True) + pdag.add_edges_from([("A", "D")], directed=False) + dag = pdag.to_dag() + self.assertTrue(("B", "C") in dag.edges()) + self.assertTrue(("A", "C") in dag.edges()) + self.assertTrue(("A", "D") in dag.edges() or ("D", "A") in dag.edges()) + self.assertEqual(set(dag.latents), {"A"}) + + # PDAG no: 3 - Already existing v-structure, possibility to add another + pdag = PDAG() + pdag.add_edges_from([("B", "C"), ("A", "C")], directed=True) + pdag.add_edges_from([("C", "D")], directed=False) + dag = pdag.to_dag() + expected_edges = {("B", "C"), ("C", "D"), ("A", "C")} + self.assertEqual(set(dag.edges()), expected_edges) + + # With latents + pdag = PDAG() + pdag.add_nodes_from(["A", "B", "C", "D"], latent=[True, False, False, False]) + pdag.add_edges_from([("B", "C"), ("A", "C")], directed=True) + pdag.add_edges_from([("C", "D")], directed=False) + dag = pdag.to_dag() + self.assertEqual(set(dag.edges()), expected_edges) + self.assertEqual(set(dag.latents), {"A"}) + + def test_pdag_to_cpdag(self): + # Test case 1 + pdag = PDAG() + pdag.add_edges_from([("A", "B")], directed=True) + pdag.add_edges_from([("B", "C")], directed=False) + cpdag = pdag.apply_meeks_rules(apply_r4=True, inplace=False) + self.assertEqual(set(cpdag.edges()), {("A", "B"), ("B", "C")}) + + # Test case 2 + pdag = PDAG() + pdag.add_edges_from([("A", "B")], directed=True) + pdag.add_edges_from([("B", "C"), ("C", "D")], directed=False) + cpdag = pdag.apply_meeks_rules(apply_r4=True, inplace=False) + self.assertEqual(set(cpdag.edges()), {("A", "B"), ("B", "C"), ("C", "D")}) + + # Test case 3 + pdag = PDAG() + pdag.add_edges_from([("A", "B"), ("D", "C")], directed=True) + pdag.add_edges_from([("B", "C")], directed=False) + cpdag = pdag.apply_meeks_rules(apply_r4=True, inplace=False) + self.assertEqual(set(cpdag.edges()), {("A", "B"), ("D", "C"), ("B", "C"), ("C", "B")}) + + # Test case 4 + pdag = PDAG() + pdag.add_edges_from([("A", "B"), ("D", "C"), ("D", "B")], directed=True) + pdag.add_edges_from([("B", "C")], directed=False) + cpdag = pdag.apply_meeks_rules(apply_r4=True, inplace=False) + self.assertEqual(set(cpdag.edges()), {("A", "B"), ("D", "C"), ("D", "B"), ("B", "C")}) + + # Test case 5 + pdag = PDAG() + pdag.add_edges_from([("A", "B"), ("B", "C")], directed=True) + pdag.add_edges_from([("A", "C")], directed=False) + cpdag = pdag.apply_meeks_rules(apply_r4=True, inplace=False) + self.assertEqual(set(cpdag.edges()), {("A", "B"), ("B", "C"), ("A", "C")}) + + # Test case 6 + pdag = PDAG() + pdag.add_edges_from([("A", "B"), ("B", "C"), ("D", "C")], directed=True) + pdag.add_edges_from([("A", "C")], directed=False) + cpdag = pdag.apply_meeks_rules(apply_r4=True, inplace=False) + self.assertEqual(set(cpdag.edges()), {("A", "B"), ("B", "C"), ("A", "C"), ("D", "C")}) + + # Perković 2017 example + pdag = PDAG() + pdag.add_edges_from([("V1", "X")], directed=True) + pdag.add_edges_from([("X", "V2"), ("V2", "Y"), ("X", "Y")], directed=False) + cpdag = pdag.apply_meeks_rules(apply_r4=True, inplace=False) + self.assertEqual( + set(cpdag.edges()), + {("V1", "X"), ("X", "V2"), ("X", "Y"), ("V2", "Y"), ("Y", "V2")} + ) + + # Perković 2017 example with reversed direction + pdag = PDAG() + pdag.add_edges_from([("Y", "X")], directed=True) + pdag.add_edges_from([("V1", "X"), ("X", "V2"), ("V2", "Y")], directed=False) + cpdag = pdag.apply_meeks_rules(apply_r4=True, inplace=False) + self.assertEqual( + set(cpdag.edges()), + {("X", "V1"), ("Y", "X"), ("X", "V2"), ("V2", "X"), ("V2", "Y"), ("Y", "V2")} + ) + + # Bang 2024 example + pdag = PDAG() + pdag.add_edges_from([("B", "D"), ("C", "D")], directed=True) + pdag.add_edges_from([("A", "D"), ("A", "C")], directed=False) + cpdag = pdag.apply_meeks_rules(apply_r4=True, inplace=False) + self.assertEqual( + set(cpdag.edges()), {("B", "D"), ("D", "A"), ("C", "A"), ("C", "D")} + ) + + # Bang 2024 example with multiple undirected edges + pdag = PDAG() + pdag.add_edges_from([("A", "B"), ("C", "B")], directed=True) + pdag.add_edges_from([("D", "B"), ("D", "A"), ("D", "C")], directed=False) + cpdag = pdag.apply_meeks_rules(apply_r4=True, inplace=False) + self.assertEqual( + set(cpdag.edges()), + {("A", "B"), ("C", "B"), ("D", "B"), ("D", "A"), ("A", "D"), ("D", "C"), ("C", "D")} + ) + + # Test with inplace=True and apply_r4=False + undirected_edges = [("A", "C"), ("B", "C"), ("D", "C")] + directed_edges = [("B", "D"), ("D", "A")] + pdag = PDAG() + harnessing = pdag.add_edges_from(directed_edges, directed=True) + pdag.add_edges_from(undirected_edges, directed=False) + pdag_inp = pdag.copy() + pdag_inp.apply_meeks_rules(apply_r4=False, inplace=True) + self.assertEqual( + set(pdag_inp.edges()), + {("A", "C"), ("C", "A"), ("C", "B"), ("B", "C"), ("B", "D"), ("D", "A"), ("D", "C"), ("C", "D")} + ) \ No newline at end of file From ed077608e8869d276758841bbd066ee91a92d42b Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 15 Aug 2025 14:15:17 +0530 Subject: [PATCH 05/13] pdag to dag more tests --- rust_core/tests/pdag_tests.rs | 218 ++++++++++++++++++++++++++++++++++ 1 file changed, 218 insertions(+) diff --git a/rust_core/tests/pdag_tests.rs b/rust_core/tests/pdag_tests.rs index 1884844..9ce6ba1 100644 --- a/rust_core/tests/pdag_tests.rs +++ b/rust_core/tests/pdag_tests.rs @@ -333,4 +333,222 @@ fn test_to_dag_basic() { assert!(!(dag.has_edge("A", "D") && dag.has_edge("C", "D"))); // No V-structure assert!(dag_edges.contains(&("C".to_string(), "D".to_string())) || dag_edges.contains(&("D".to_string(), "C".to_string()))); assert!(dag_edges.contains(&("D".to_string(), "A".to_string())) || dag_edges.contains(&("A".to_string(), "D".to_string()))); +} + +#[test] +fn test_pdag_to_dag() { + // PDAG no: 1 - Possibility of creating a v-structure + let mut pdag = RustPDAG::new(); + pdag.add_edges_from( + Some(vec![ + ("A".to_string(), "B".to_string()), + ("C".to_string(), "B".to_string()), + ]), + None, + true, + ) + .unwrap(); + pdag.add_edges_from( + Some(vec![ + ("C".to_string(), "D".to_string()), + ("D".to_string(), "A".to_string()), + ]), + None, + false, + ) + .unwrap(); + + let dag = pdag.to_dag().unwrap(); + let dag_edges: HashSet<(String, String)> = dag.edges().into_iter().collect(); + + assert_eq!(dag_edges.len(), 4, "Expected 4 edges in DAG"); + assert!(dag.has_edge("A", "B"), "Expected edge A -> B"); + assert!(dag.has_edge("C", "B"), "Expected edge C -> B"); + assert!( + !(dag.has_edge("A", "D") && dag.has_edge("C", "D")), + "Should not have both A -> D and C -> D (v-structure)" + ); + assert!( + dag_edges.contains(&("C".to_string(), "D".to_string())) + || dag_edges.contains(&("D".to_string(), "C".to_string())), + "Expected either C -> D or D -> C" + ); + assert!( + dag_edges.contains(&("D".to_string(), "A".to_string())) + || dag_edges.contains(&("A".to_string(), "D".to_string())), + "Expected either D -> A or A -> D" + ); + + // With latents + let mut pdag = RustPDAG::new(); + pdag.add_nodes_from( + vec!["A".to_string(), "B".to_string(), "C".to_string(), "D".to_string()], + Some(vec![true, false, false, false]), + ) + .unwrap(); + pdag.add_edges_from( + Some(vec![ + ("A".to_string(), "B".to_string()), + ("C".to_string(), "B".to_string()), + ]), + None, + true, + ) + .unwrap(); + pdag.add_edges_from( + Some(vec![ + ("C".to_string(), "D".to_string()), + ("D".to_string(), "A".to_string()), + ]), + None, + false, + ) + .unwrap(); + + let dag = pdag.to_dag().unwrap(); + let dag_edges: HashSet<(String, String)> = dag.edges().into_iter().collect(); + let dag_latents: HashSet = dag.latents.clone().into_iter().collect(); + + assert_eq!(dag_edges.len(), 4, "Expected 4 edges in DAG with latents"); + assert!(dag.has_edge("A", "B"), "Expected edge A -> B with latents"); + assert!(dag.has_edge("C", "B"), "Expected edge C -> B with latents"); + assert!( + !(dag.has_edge("A", "D") && dag.has_edge("C", "D")), + "Should not have both A -> D and C -> D with latents (v-structure)" + ); + assert!( + dag_edges.contains(&("C".to_string(), "D".to_string())) + || dag_edges.contains(&("D".to_string(), "C".to_string())), + "Expected either C -> D or D -> C with latents" + ); + assert!( + dag_edges.contains(&("D".to_string(), "A".to_string())) + || dag_edges.contains(&("A".to_string(), "D".to_string())), + "Expected either D -> A or A -> D with latents" + ); + assert_eq!( + dag_latents, + HashSet::from_iter(vec!["A".to_string()]), + "Expected latent node A" + ); + + // PDAG no: 2 - No possibility of creating a v-structure + let mut pdag = RustPDAG::new(); + pdag.add_edges_from( + Some(vec![ + ("B".to_string(), "C".to_string()), + ("A".to_string(), "C".to_string()), + ]), + None, + true, + ) + .unwrap(); + pdag.add_edges_from(Some(vec![("A".to_string(), "D".to_string())]), None, false) + .unwrap(); + + let dag = pdag.to_dag().unwrap(); + let dag_edges: HashSet<(String, String)> = dag.edges().into_iter().collect(); + + assert!(dag.has_edge("B", "C"), "Expected edge B -> C"); + assert!(dag.has_edge("A", "C"), "Expected edge A -> C"); + assert!( + dag_edges.contains(&("A".to_string(), "D".to_string())) + || dag_edges.contains(&("D".to_string(), "A".to_string())), + "Expected either A -> D or D -> A" + ); + + // With latents + let mut pdag = RustPDAG::new(); + pdag.add_nodes_from( + vec!["A".to_string(), "B".to_string(), "C".to_string(), "D".to_string()], + Some(vec![true, false, false, false]), + ) + .unwrap(); + pdag.add_edges_from( + Some(vec![ + ("B".to_string(), "C".to_string()), + ("A".to_string(), "C".to_string()), + ]), + None, + true, + ) + .unwrap(); + pdag.add_edges_from(Some(vec![("A".to_string(), "D".to_string())]), None, false) + .unwrap(); + + let dag = pdag.to_dag().unwrap(); + let dag_edges: HashSet<(String, String)> = dag.edges().into_iter().collect(); + let dag_latents: HashSet = dag.latents.clone().into_iter().collect(); + + assert!(dag.has_edge("B", "C"), "Expected edge B -> C with latents"); + assert!(dag.has_edge("A", "C"), "Expected edge A -> C with latents"); + assert!( + dag_edges.contains(&("A".to_string(), "D".to_string())) + || dag_edges.contains(&("D".to_string(), "A".to_string())), + "Expected either A -> D or D -> A with latents" + ); + assert_eq!( + dag_latents, + HashSet::from_iter(vec!["A".to_string()]), + "Expected latent node A" + ); + + // PDAG no: 3 - Already existing v-structure, possibility to add another + let mut pdag = RustPDAG::new(); + pdag.add_edges_from( + Some(vec![ + ("B".to_string(), "C".to_string()), + ("A".to_string(), "C".to_string()), + ]), + None, + true, + ) + .unwrap(); + pdag.add_edges_from(Some(vec![("C".to_string(), "D".to_string())]), None, false) + .unwrap(); + + let dag = pdag.to_dag().unwrap(); + let expected_edges: HashSet<(String, String)> = vec![ + ("B".to_string(), "C".to_string()), + ("C".to_string(), "D".to_string()), + ("A".to_string(), "C".to_string()), + ] + .into_iter() + .collect(); + let dag_edges: HashSet<(String, String)> = dag.edges().into_iter().collect(); + + assert_eq!(dag_edges, expected_edges, "Expected edges for PDAG no: 3"); + + // With latents + let mut pdag: RustPDAG = RustPDAG::new(); + pdag.add_nodes_from( + vec!["A".to_string(), "B".to_string(), "C".to_string(), "D".to_string()], + Some(vec![true, false, false, false]), + ) + .unwrap(); + pdag.add_edges_from( + Some(vec![ + ("B".to_string(), "C".to_string()), + ("A".to_string(), "C".to_string()), + ]), + None, + true, + ) + .unwrap(); + pdag.add_edges_from(Some(vec![("C".to_string(), "D".to_string())]), None, false) + .unwrap(); + + let dag = pdag.to_dag().unwrap(); + let dag_edges: HashSet<(String, String)> = dag.edges().into_iter().collect(); + let dag_latents: HashSet = dag.latents.into_iter().collect(); + + assert_eq!( + dag_edges, expected_edges, + "Expected edges for PDAG no: 3 with latents" + ); + assert_eq!( + dag_latents, + HashSet::from_iter(vec!["A".to_string()]), + "Expected latent node A" + ); } \ No newline at end of file From 2868193292f157531187e4bbafddeba4382ea446 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 15 Aug 2025 15:04:08 +0530 Subject: [PATCH 06/13] add latent method to py DAG --- python_bindings/src/lib.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python_bindings/src/lib.rs b/python_bindings/src/lib.rs index 0e352f0..b1dbfb5 100644 --- a/python_bindings/src/lib.rs +++ b/python_bindings/src/lib.rs @@ -287,6 +287,13 @@ impl PyRustDAG { .minimal_dseparator(&start, &end, include_latents) .map_err(PyValueError::new_err) } + + #[getter] + fn latents(&self) -> Vec { + let mut result: Vec = self.inner.latents.iter().cloned().collect(); + result.sort(); + result + } } // Existing PyIndependenceAssertion and PyIndependencies (unchanged, included for completeness) From 6dee738500c94915e87a5148e10fd9668a1bec58 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 15 Aug 2025 15:05:43 +0530 Subject: [PATCH 07/13] fix direction typo --- rust_core/src/pdag.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust_core/src/pdag.rs b/rust_core/src/pdag.rs index c917fa6..b2da83a 100644 --- a/rust_core/src/pdag.rs +++ b/rust_core/src/pdag.rs @@ -553,7 +553,7 @@ impl RustPDAG { let mut all_predecessors: Vec = pdag_copy.all_neighbors(x)?.into_iter().collect(); all_predecessors.sort(); for y in &all_predecessors { - if pdag_copy.is_adjacent(y, x) && !dag.has_edge(x, y) { + if pdag_copy.is_adjacent(y, x) && !dag.has_edge(y, x) { dag.add_edge(y.clone(), x.clone(), None)?; } } From 388a2bd466a995ce2b40e43bff8cbdc73390d976 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 15 Aug 2025 16:04:49 +0530 Subject: [PATCH 08/13] fix meeks rule 3 --- rust_core/src/pdag.rs | 63 +++++++++------- rust_core/tests/pdag_tests.rs | 135 ++++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+), 28 deletions(-) diff --git a/rust_core/src/pdag.rs b/rust_core/src/pdag.rs index b2da83a..848d706 100644 --- a/rust_core/src/pdag.rs +++ b/rust_core/src/pdag.rs @@ -364,16 +364,10 @@ impl RustPDAG { for x in &directed_parents { for z in &undirected_neighbors { - if !self.is_adjacent(x, z) - && !self.check_new_unshielded_collider(y, z)? - && !self.has_directed_path(z, y)? - { - // Ensure x -> y exists - if self.has_directed_edge(x, y) { - if self.orient_undirected_edge(y, z, true).is_ok() { - changed = true; - break; - } + if !self.is_adjacent(x, z) && !self.check_new_unshielded_collider(y, z)? { + if self.orient_undirected_edge(y, z, true).is_ok() { + changed = true; + break; } } } @@ -381,8 +375,9 @@ impl RustPDAG { } if changed { break; } } + if changed { continue; } - // Rule 2: If X -> Z and Z -> Y and X - Y => X -> Y + // Rule 2: If X -> Z -> Y and X - Y => X -> Y for z in &nodes { if !self.node_map.contains_key(z) { continue; @@ -409,13 +404,13 @@ impl RustPDAG { } if changed { break; } } + if changed { continue; } // Rule 3: If X - Y, X - Z, X - W and Y -> W, Z -> W => X -> W for x in &nodes { if !self.node_map.contains_key(x) { continue; } - // Convert HashSet to sorted vector for deterministic iteration let mut undirected_nbs: Vec = self.undirected_neighbors(x)?.into_iter().collect(); undirected_nbs.sort(); @@ -425,10 +420,20 @@ impl RustPDAG { for i in 0..undirected_nbs.len() { for j in (i + 1)..undirected_nbs.len() { - for k in (j + 1)..undirected_nbs.len() { - let (y, z, w) = (&undirected_nbs[i], &undirected_nbs[j], &undirected_nbs[k]); + let y = &undirected_nbs[i]; + let z = &undirected_nbs[j]; + + if self.is_adjacent(y, z) { + continue; + } + + let y_children = self.directed_children(y)?; + let z_children = self.directed_children(z)?; - if self.has_directed_edge(y, w) && self.has_directed_edge(z, w) { + let common_children: HashSet<_> = y_children.intersection(&z_children).collect(); + + for w in common_children { + if self.has_undirected_edge(x, w) { if self.orient_undirected_edge(x, w, true).is_ok() { changed = true; break; @@ -441,6 +446,7 @@ impl RustPDAG { } if changed { break; } } + if changed { continue; } // Rule 4 if apply_r4 { @@ -460,19 +466,20 @@ impl RustPDAG { continue; } - let mut b_undirected: Vec = self.undirected_neighbors(b)?.into_iter().collect(); - b_undirected.sort(); - let mut c_neighbors: Vec = self.all_neighbors(c)?.into_iter().collect(); - c_neighbors.sort(); - let mut d_undirected: Vec = self.undirected_neighbors(d)?.into_iter().collect(); - d_undirected.sort(); - - for a in &b_undirected { - if c_neighbors.contains(a) && d_undirected.contains(a) { - if self.orient_undirected_edge(a, b, true).is_ok() { - changed = true; - break; - } + let b_undirected_set = self.undirected_neighbors(b)?; + let c_neighbors_set = self.all_neighbors(c)?; + let d_undirected_set = self.undirected_neighbors(d)?; + + let candidates: HashSet<_> = b_undirected_set.intersection(&c_neighbors_set).cloned().collect(); + let final_candidates: HashSet<_> = candidates.intersection(&d_undirected_set).cloned().collect(); + + let mut sorted_candidates: Vec = final_candidates.into_iter().collect(); + sorted_candidates.sort(); + + for a in sorted_candidates { + if self.orient_undirected_edge(&a, b, true).is_ok() { + changed = true; + break; } } if changed { break; } diff --git a/rust_core/tests/pdag_tests.rs b/rust_core/tests/pdag_tests.rs index 9ce6ba1..c2cba32 100644 --- a/rust_core/tests/pdag_tests.rs +++ b/rust_core/tests/pdag_tests.rs @@ -293,6 +293,28 @@ fn test_apply_meeks_rules_no_change() { assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges); } + +#[test] +fn test_apply_meeks_rules_no_change_2() { + // Test case where no rules apply + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "B".to_string()), ("D".to_string(), "C".to_string()), ("D".to_string(), "B".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("B".to_string(), "C".to_string())]), None, false).unwrap(); + + let cpdag = pdag.apply_meeks_rules(true, false).unwrap().unwrap(); + let expected_edges: HashSet<(String, String)> = vec![ + ("A".to_string(), "B".to_string()), + ("D".to_string(), "C".to_string()), + ("D".to_string(), "B".to_string()), + ("B".to_string(), "C".to_string()), + ].into_iter().collect(); + + assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges); +} + + + + #[test] fn test_apply_meeks_rules_inplace() { // Test inplace modification @@ -308,6 +330,119 @@ fn test_apply_meeks_rules_inplace() { } +#[test] +fn test_meeks_rules_perkovic_2017() { + // Test case from Perkoviċ et al., 2017 + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("V1".to_string(), "X".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("X".to_string(), "V2".to_string()), ("V2".to_string(), "Y".to_string()), ("X".to_string(), "Y".to_string())]), None, false).unwrap(); + + let cpdag = pdag.apply_meeks_rules(true, false).unwrap().unwrap(); + let expected_edges: HashSet<(String, String)> = vec![ + ("V1".to_string(), "X".to_string()), + ("X".to_string(), "V2".to_string()), + ("X".to_string(), "Y".to_string()), + ("V2".to_string(), "Y".to_string()), + ("Y".to_string(), "V2".to_string()) + ].into_iter().collect(); + assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges); + + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("Y".to_string(), "X".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("V1".to_string(), "X".to_string()), ("X".to_string(), "V2".to_string()), ("V2".to_string(), "Y".to_string())]), None, false).unwrap(); + + let cpdag = pdag.apply_meeks_rules(true, false).unwrap().unwrap(); + let expected_edges: HashSet<(String, String)> = vec![ + ("X".to_string(), "V1".to_string()), + ("Y".to_string(), "X".to_string()), + ("X".to_string(), "V2".to_string()), + ("V2".to_string(), "X".to_string()), + ("V2".to_string(), "Y".to_string()), + ("Y".to_string(), "V2".to_string()), + ].into_iter().collect(); + assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges); +} + + +#[test] +fn test_meeks_rules_bang_2024() { + // Test case from Bang et al., 2024 + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("B".to_string(), "D".to_string()), ("C".to_string(), "D".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("A".to_string(), "D".to_string()), ("A".to_string(), "C".to_string())]), None, false).unwrap(); + + let cpdag = pdag.apply_meeks_rules(true, false).unwrap().unwrap(); + let expected_edges: HashSet<(String, String)> = vec![ + ("B".to_string(), "D".to_string()), + ("D".to_string(), "A".to_string()), + ("C".to_string(), "A".to_string()), + ("C".to_string(), "D".to_string()) + ].into_iter().collect(); + assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges); + + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "B".to_string()), ("C".to_string(), "B".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("D".to_string(), "B".to_string()), ("D".to_string(), "A".to_string()), ("D".to_string(), "C".to_string())]), None, false).unwrap(); + + let cpdag = pdag.apply_meeks_rules(true, false).unwrap().unwrap(); + let expected_edges: HashSet<(String, String)> = vec![ + ("A".to_string(), "B".to_string()), + ("C".to_string(), "B".to_string()), + ("D".to_string(), "B".to_string()), + ("D".to_string(), "A".to_string()), + ("A".to_string(), "D".to_string()), + ("D".to_string(), "C".to_string()), + ("C".to_string(), "D".to_string()), + ].into_iter().collect(); + assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges); +} + +#[test] +fn test_meeks_rules_complex_cases() { + let undirected_edges = vec![("A".to_string(), "C".to_string()), ("B".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]; + let directed_edges = vec![("B".to_string(), "D".to_string()), ("D".to_string(), "A".to_string())]; + + // With apply_r4 = true + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(directed_edges.clone()), None, true).unwrap(); + pdag.add_edges_from(Some(undirected_edges.clone()), None, false).unwrap(); + let cpdag = pdag.apply_meeks_rules(true, false).unwrap().unwrap(); + let expected_edges_r4: HashSet<(String, String)> = vec![ + ("C".to_string(), "A".to_string()), + ("C".to_string(), "B".to_string()), + ("B".to_string(), "C".to_string()), + ("B".to_string(), "D".to_string()), + ("D".to_string(), "A".to_string()), + ("D".to_string(), "C".to_string()), + ("C".to_string(), "D".to_string()), + ].into_iter().collect(); + assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges_r4); + + // With apply_r4 = false + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(directed_edges.clone()), None, true).unwrap(); + pdag.add_edges_from(Some(undirected_edges.clone()), None, false).unwrap(); + let cpdag = pdag.apply_meeks_rules(false, false).unwrap().unwrap(); + let expected_edges_no_r4: HashSet<(String, String)> = vec![ + ("A".to_string(), "C".to_string()), + ("C".to_string(), "A".to_string()), + ("C".to_string(), "B".to_string()), + ("B".to_string(), "C".to_string()), + ("B".to_string(), "D".to_string()), + ("D".to_string(), "A".to_string()), + ("D".to_string(), "C".to_string()), + ("C".to_string(), "D".to_string()), + ].into_iter().collect(); + assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges_no_r4); + + // With apply_r4 = false and inplace = true + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(directed_edges.clone()), None, true).unwrap(); + pdag.add_edges_from(Some(undirected_edges.clone()), None, false).unwrap(); + pdag.apply_meeks_rules(false, true).unwrap(); + assert_eq!(pdag.edges().into_iter().collect::>(), expected_edges_no_r4); +} + #[test] fn test_to_dag_basic() { From c90beea95f6e08f602c4573fc92ac9d4c8369875 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 15 Aug 2025 20:19:42 +0530 Subject: [PATCH 09/13] minor fixes --- python_bindings/src/lib.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/python_bindings/src/lib.rs b/python_bindings/src/lib.rs index b1dbfb5..2b39a67 100644 --- a/python_bindings/src/lib.rs +++ b/python_bindings/src/lib.rs @@ -118,6 +118,29 @@ impl PyRustPDAG { self.inner.edges() } + /// Get all directed edges in the PDAG. + pub fn directed_edges(&self) -> Vec<(String, String)> { + let mut edges: Vec<(String, String)> = self.inner.directed_edges.iter().cloned().collect(); + edges.sort(); + edges + } + + /// Get all undirected edges in the PDAG (returns one direction per pair). + pub fn undirected_edges(&self) -> Vec<(String, String)> { + let mut edges: Vec<(String, String)> = self.inner.undirected_edges.iter().cloned().collect(); + edges.sort(); + edges + } + + /// Orient an undirected edge u - v as u -> v. + #[pyo3(signature = (u, v, inplace = false))] + pub fn orient_undirected_edge(&mut self, u: String, v: String, inplace: bool) -> PyResult> { + self.inner + .orient_undirected_edge(&u, &v, inplace) + .map(|opt| opt.map(|pdag| PyRustPDAG { inner: pdag })) + .map_err(PyValueError::new_err) + } + /// Apply Meek's rules to orient undirected edges. #[pyo3(signature = (apply_r4 = true, inplace = true))] pub fn apply_meeks_rules(&mut self, apply_r4: bool, inplace: bool) -> PyResult> { @@ -152,6 +175,13 @@ impl PyRustPDAG { self.inner.directed_edges.len() + self.inner.undirected_edges.len() } + // Copy PDAG to a new instance. + pub fn copy(&self) -> Self { + PyRustPDAG { + inner: self.inner.clone(), + } + } + /// Get the set of latent nodes. #[getter] pub fn latents(&self) -> Vec { From b6a3fac9321439213589ed22519477da012575e5 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 15 Aug 2025 20:33:08 +0530 Subject: [PATCH 10/13] major dfs fix --- rust_core/src/pdag.rs | 55 ++++++++++++++++++++--------------- rust_core/tests/pdag_tests.rs | 19 ++++++++++++ 2 files changed, 50 insertions(+), 24 deletions(-) diff --git a/rust_core/src/pdag.rs b/rust_core/src/pdag.rs index 848d706..12afc3b 100644 --- a/rust_core/src/pdag.rs +++ b/rust_core/src/pdag.rs @@ -278,16 +278,19 @@ impl RustPDAG { return Err(format!("Undirected Edge {} - {} not present in the PDAG", u, v)); } - // Remove the reverse edge from the graph let u_idx = pdag.node_map[u]; let v_idx = pdag.node_map[v]; - // Find and remove the edge v -> u + // Remove both directions from graph to avoid duplicates if let Some(edge_idx) = pdag.graph.find_edge(v_idx, u_idx) { pdag.graph.remove_edge(edge_idx); } + if let Some(edge_idx) = pdag.graph.find_edge(u_idx, v_idx) { + pdag.graph.remove_edge(edge_idx); + } - // Add to directed edges + // Add the directed edge u -> v to the graph + pdag.graph.add_edge(u_idx, v_idx, 1.0); pdag.directed_edges.insert((u.to_string(), v.to_string())); if inplace { @@ -297,6 +300,7 @@ impl RustPDAG { } } + /// Check if orienting u -> v would create a new unshielded collider fn check_new_unshielded_collider(&self, u: &str, v: &str) -> Result { let parents = self.directed_parents(v)?; @@ -311,12 +315,16 @@ impl RustPDAG { /// Check if there's a path from source to target in the directed subgraph pub fn has_directed_path(&self, source: &str, target: &str) -> Result { - let source_idx = self.node_map.get(source) + let directed_graph = self.directed_graph(); + + let source_idx = directed_graph.node_map.get(source) .ok_or_else(|| format!("Node {} not found", source))?; - let target_idx = self.node_map.get(target) + let target_idx = directed_graph.node_map.get(target) .ok_or_else(|| format!("Node {} not found", target))?; - let directed_graph = self.directed_graph(); + + println!("Directed graph edges: {:?}", directed_graph.edges()); + println!("Source index: {:?}, Target index: {:?}", source_idx, target_idx); let mut dfs = Dfs::new(&directed_graph.graph, *source_idx); while let Some(nx) = dfs.next(&directed_graph.graph) { @@ -356,33 +364,38 @@ impl RustPDAG { if !self.node_map.contains_key(y) { continue; } - // Convert HashSets to sorted vectors for deterministic iteration let mut directed_parents: Vec = self.directed_parents(y)?.into_iter().collect(); directed_parents.sort(); let mut undirected_neighbors: Vec = self.undirected_neighbors(y)?.into_iter().collect(); undirected_neighbors.sort(); + println!("Rule 1: Y={}, Parents={:?}, Undirected Neighbors={:?}", y, directed_parents, undirected_neighbors); for x in &directed_parents { for z in &undirected_neighbors { - if !self.is_adjacent(x, z) && !self.check_new_unshielded_collider(y, z)? { + let adj = self.is_adjacent(x, z); + let collider = self.check_new_unshielded_collider(y, z)?; + let path = self.has_directed_path(z, y)?; + if !adj && !collider && !path { + println!("Orienting {} -> {} for Rule 1", y, z); if self.orient_undirected_edge(y, z, true).is_ok() { changed = true; break; - } + } } } if changed { break; } } if changed { break; } } - if changed { continue; } + if changed { + continue; + } // Rule 2: If X -> Z -> Y and X - Y => X -> Y for z in &nodes { if !self.node_map.contains_key(z) { continue; } - // Convert HashSets to sorted vectors for deterministic iteration let mut parents: Vec = self.directed_parents(z)?.into_iter().collect(); parents.sort(); let mut children: Vec = self.directed_children(z)?.into_iter().collect(); @@ -391,8 +404,10 @@ impl RustPDAG { for x in &parents { for y in &children { if self.has_undirected_edge(x, y) { - // Ensure x -> z and z -> y exist - if self.has_directed_edge(x, z) && self.has_directed_edge(z, y) { + let x_to_z = self.has_directed_edge(x, z); + let z_to_y = self.has_directed_edge(z, y); + if x_to_z && z_to_y { + println!("Orienting {} -> {} for Rule 2", x, y); if self.orient_undirected_edge(x, y, true).is_ok() { changed = true; break; @@ -404,7 +419,9 @@ impl RustPDAG { } if changed { break; } } - if changed { continue; } + if changed { + continue; + } // Rule 3: If X - Y, X - Z, X - W and Y -> W, Z -> W => X -> W for x in &nodes { @@ -422,16 +439,12 @@ impl RustPDAG { for j in (i + 1)..undirected_nbs.len() { let y = &undirected_nbs[i]; let z = &undirected_nbs[j]; - if self.is_adjacent(y, z) { continue; } - let y_children = self.directed_children(y)?; let z_children = self.directed_children(z)?; - let common_children: HashSet<_> = y_children.intersection(&z_children).collect(); - for w in common_children { if self.has_undirected_edge(x, w) { if self.orient_undirected_edge(x, w, true).is_ok() { @@ -454,7 +467,6 @@ impl RustPDAG { if !self.node_map.contains_key(c) { continue; } - let mut children: Vec = self.directed_children(c)?.into_iter().collect(); children.sort(); let mut parents: Vec = self.directed_parents(c)?.into_iter().collect(); @@ -465,17 +477,13 @@ impl RustPDAG { if b == d || self.is_adjacent(b, d) { continue; } - let b_undirected_set = self.undirected_neighbors(b)?; let c_neighbors_set = self.all_neighbors(c)?; let d_undirected_set = self.undirected_neighbors(d)?; - let candidates: HashSet<_> = b_undirected_set.intersection(&c_neighbors_set).cloned().collect(); let final_candidates: HashSet<_> = candidates.intersection(&d_undirected_set).cloned().collect(); - let mut sorted_candidates: Vec = final_candidates.into_iter().collect(); sorted_candidates.sort(); - for a in sorted_candidates { if self.orient_undirected_edge(&a, b, true).is_ok() { changed = true; @@ -490,7 +498,6 @@ impl RustPDAG { } } } - Ok(()) } diff --git a/rust_core/tests/pdag_tests.rs b/rust_core/tests/pdag_tests.rs index c2cba32..2c18d11 100644 --- a/rust_core/tests/pdag_tests.rs +++ b/rust_core/tests/pdag_tests.rs @@ -293,6 +293,25 @@ fn test_apply_meeks_rules_no_change() { assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges); } +#[test] +fn test_apply_meeks_rules_no_change_1() { + // Test case where no rules apply + let mut pdag = RustPDAG::new(); + pdag.add_edges_from(Some(vec![("A".to_string(), "B".to_string()), ("B".to_string(), "C".to_string()), ("D".to_string(), "C".to_string())]), None, true).unwrap(); + pdag.add_edges_from(Some(vec![("A".to_string(), "C".to_string())]), None, false).unwrap(); + + let cpdag = pdag.apply_meeks_rules(true, false).unwrap().unwrap(); + let expected_edges: HashSet<(String, String)> = vec![ + ("A".to_string(), "B".to_string()), + ("B".to_string(), "C".to_string()), + ("A".to_string(), "C".to_string()), + ("D".to_string(), "C".to_string()), + ].into_iter().collect(); + + assert_eq!(cpdag.edges().into_iter().collect::>(), expected_edges); +} + + #[test] fn test_apply_meeks_rules_no_change_2() { From cbd91744dd0fd01824b5671816cf019cc03ec6d0 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 15 Aug 2025 20:36:11 +0530 Subject: [PATCH 11/13] refactor --- rust_core/src/pdag.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/rust_core/src/pdag.rs b/rust_core/src/pdag.rs index 12afc3b..7a4f51b 100644 --- a/rust_core/src/pdag.rs +++ b/rust_core/src/pdag.rs @@ -323,8 +323,6 @@ impl RustPDAG { .ok_or_else(|| format!("Node {} not found", target))?; - println!("Directed graph edges: {:?}", directed_graph.edges()); - println!("Source index: {:?}, Target index: {:?}", source_idx, target_idx); let mut dfs = Dfs::new(&directed_graph.graph, *source_idx); while let Some(nx) = dfs.next(&directed_graph.graph) { @@ -368,7 +366,6 @@ impl RustPDAG { directed_parents.sort(); let mut undirected_neighbors: Vec = self.undirected_neighbors(y)?.into_iter().collect(); undirected_neighbors.sort(); - println!("Rule 1: Y={}, Parents={:?}, Undirected Neighbors={:?}", y, directed_parents, undirected_neighbors); for x in &directed_parents { for z in &undirected_neighbors { @@ -376,7 +373,6 @@ impl RustPDAG { let collider = self.check_new_unshielded_collider(y, z)?; let path = self.has_directed_path(z, y)?; if !adj && !collider && !path { - println!("Orienting {} -> {} for Rule 1", y, z); if self.orient_undirected_edge(y, z, true).is_ok() { changed = true; break; @@ -407,7 +403,6 @@ impl RustPDAG { let x_to_z = self.has_directed_edge(x, z); let z_to_y = self.has_directed_edge(z, y); if x_to_z && z_to_y { - println!("Orienting {} -> {} for Rule 2", x, y); if self.orient_undirected_edge(x, y, true).is_ok() { changed = true; break; From 01fbdf9760e5d4080b86068df267db815ee25ad2 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Sun, 24 Aug 2025 13:49:44 +0530 Subject: [PATCH 12/13] add PDAG wasm bindings --- wasm_bindings/js/tests/test-pdag.js | 81 ++++++++++++++++++ wasm_bindings/src/lib.rs | 122 ++++++++++++++++++++++++++++ 2 files changed, 203 insertions(+) create mode 100644 wasm_bindings/js/tests/test-pdag.js diff --git a/wasm_bindings/js/tests/test-pdag.js b/wasm_bindings/js/tests/test-pdag.js new file mode 100644 index 0000000..e4ee7a7 --- /dev/null +++ b/wasm_bindings/js/tests/test-pdag.js @@ -0,0 +1,81 @@ +const cg = require("../pkg-node/causalgraphs_wasm.js"); + + +describe('cg.PDAG', () => { + it('can be instantiated', () => { + const pdag = new cg.PDAG(); + expect(pdag.nodeCount).toBe(0); + expect(pdag.edgeCount).toBe(0); + }); + + it('can add nodes and edges', () => { + const pdag = new cg.PDAG(); + pdag.addNode('A'); + pdag.addNode('B'); + pdag.addNode('C'); + pdag.addEdge('A', 'B', null, true); // A -> B + pdag.addEdge('B', 'C', null, false); // B - C + + expect(pdag.nodeCount).toBe(3); + expect(pdag.edgeCount).toBe(2); + + const directedEdges = new Set(pdag.directedEdges().map(e => e.join(','))); + expect(directedEdges).toEqual(new Set(['A,B'])); + + const undirectedEdges = new Set(pdag.undirectedEdges().map(e => e.sort().join(','))); + expect(undirectedEdges).toEqual(new Set(['B,C'])); + }); + + it('can add multiple edges from a list', () => { + const pdag = new cg.PDAG(); + const directed = [['A', 'B'], ['D', 'C']]; + const undirected = [['B', 'C']]; + pdag.addEdgesFrom(directed, null, true); + pdag.addEdgesFrom(undirected, null, false); + + expect(pdag.nodeCount).toBe(4); + expect(pdag.edgeCount).toBe(3); + expect(pdag.nodes().sort()).toEqual(['A', 'B', 'C', 'D']); + }); + + it("applies Meek's rules correctly (basic case)", () => { + const pdag = new cg.PDAG(); + pdag.addEdge('A', 'B', null, true); // A -> B + pdag.addEdge('B', 'C', null, false); // B - C + + const cpdag = pdag.applyMeeksRules(true, false); + const expectedEdges = new Set(['A,B', 'B,C']); + const actualEdges = new Set(cpdag.edges().map(e => e.join(','))); + + expect(actualEdges).toEqual(expectedEdges); + }); + + it("applies Meek's rules correctly (no change)", () => { + const pdag = new cg.PDAG(); + pdag.addEdgesFrom([['A', 'B'], ['D', 'C']], null, true); + pdag.addEdgesFrom([['B', 'C']], null, false); + + const cpdag = pdag.applyMeeksRules(true, false); + + // Expect B-C to remain undirected + const directed = new Set(cpdag.directedEdges().map(e => e.join(','))); + const undirected = new Set(cpdag.undirectedEdges().map(e => e.sort().join(','))); + + expect(directed).toEqual(new Set(['A,B', 'D,C'])); + expect(undirected).toEqual(new Set(['B,C'])); + }); + + it('converts to a DAG', () => { + const pdag = new cg.PDAG(); + pdag.addEdge('A', 'B', null, true); + pdag.addEdge('B', 'C', null, false); + + const dag = pdag.toDag(); + expect(dag.constructor.name).toBe('RustDAG'); + + const dagEdges = new Set(dag.edges().map(e => e.join(','))); + // to_dag is consistent, so B-C will be oriented B->C in this case + const expectedEdges = new Set(['A,B', 'B,C']); + expect(dagEdges).toEqual(expectedEdges); + }); +}); \ No newline at end of file diff --git a/wasm_bindings/src/lib.rs b/wasm_bindings/src/lib.rs index 76ca96d..c533c0b 100644 --- a/wasm_bindings/src/lib.rs +++ b/wasm_bindings/src/lib.rs @@ -94,6 +94,128 @@ impl RustDAG { } } +#[wasm_bindgen(js_name = PDAG)] +pub struct PDAG { + inner: rust_core::RustPDAG, +} + +#[wasm_bindgen] +impl PDAG { + #[wasm_bindgen(constructor)] + pub fn new() -> Self { + Self { + inner: rust_core::RustPDAG::new(), + } + } + + #[wasm_bindgen(js_name = addNode, catch)] + pub fn add_node(&mut self, node: String, latent: Option) -> Result<(), JsValue> { + self.inner + .add_node(node, latent.unwrap_or(false)) + .map_err(|e| JsValue::from_str(&e)) + } + + #[wasm_bindgen(js_name = addNodesFrom, catch)] + pub fn add_nodes_from( + &mut self, + nodes: Vec, + latent: Option>, + ) -> Result<(), JsValue> { + let latent_bools = latent.map(|v| v.into_iter().map(|x| x != 0).collect()); + self.inner + .add_nodes_from(nodes, latent_bools) + .map_err(|e| JsValue::from_str(&e)) + } + + #[wasm_bindgen(js_name = addEdge, catch)] + pub fn add_edge( + &mut self, + u: String, + v: String, + weight: Option, + directed: bool, + ) -> Result<(), JsValue> { + self.inner + .add_edge(u, v, weight, directed) + .map_err(|e| JsValue::from_str(&e)) + } + + #[wasm_bindgen(js_name = addEdgesFrom, catch)] + pub fn add_edges_from( + &mut self, + ebunch: JsValue, + weights: Option>, + directed: bool, + ) -> Result<(), JsValue> { + let ebunch_vec: Vec<(String, String)> = serde_wasm_bindgen::from_value(ebunch)?; + self.inner + .add_edges_from(Some(ebunch_vec), weights, directed) + .map_err(|e| JsValue::from_str(&e)) + } + + #[wasm_bindgen(js_name = nodes)] + pub fn nodes(&self) -> Vec { + self.inner.nodes() + } + + #[wasm_bindgen(js_name = edges)] + pub fn edges(&self) -> JsValue { + serde_wasm_bindgen::to_value(&self.inner.edges()).unwrap() + } + + #[wasm_bindgen(js_name = directedEdges)] + pub fn directed_edges(&self) -> JsValue { + serde_wasm_bindgen::to_value(&self.inner.directed_edges).unwrap() + } + + #[wasm_bindgen(js_name = undirectedEdges)] + pub fn undirected_edges(&self) -> JsValue { + serde_wasm_bindgen::to_value(&self.inner.undirected_edges).unwrap() + } + + #[wasm_bindgen(js_name = nodeCount, getter)] + pub fn node_count(&self) -> usize { + self.inner.node_map.len() + } + + #[wasm_bindgen(js_name = edgeCount, getter)] + pub fn edge_count(&self) -> usize { + self.inner.directed_edges.len() + self.inner.undirected_edges.len() + } + + #[wasm_bindgen(js_name = latents, getter)] + pub fn latents(&self) -> JsValue { + serde_wasm_bindgen::to_value(&self.inner.latents).unwrap() + } + + #[wasm_bindgen(js_name = applyMeeksRules, catch)] + pub fn apply_meeks_rules( + &mut self, + apply_r4: bool, + inplace: bool, + ) -> Result, JsValue> { + self.inner + .apply_meeks_rules(apply_r4, inplace) + .map(|opt| opt.map(|pdag| PDAG { inner: pdag })) + .map_err(|e| JsValue::from_str(&e)) + } + + #[wasm_bindgen(js_name = toDag, catch)] + pub fn to_dag(&self) -> Result { + self.inner + .to_dag() + .map(|dag| RustDAG { inner: dag }) + .map_err(|e| JsValue::from_str(&e)) + } + + #[wasm_bindgen(js_name = copy)] + pub fn copy(&self) -> PDAG { + PDAG { + inner: self.inner.clone(), + } + } +} + // Optional: Add a start function for debugging or initialization #[wasm_bindgen(start)] pub fn main_js() -> Result<(), JsValue> { From f539ebdb4aff418071c7f752c3fab77f2fb81f40 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Sun, 24 Aug 2025 19:38:26 +0530 Subject: [PATCH 13/13] Pdag R bindings --- r_bindings/causalgraphs/NAMESPACE | 2 + r_bindings/causalgraphs/R/extendr-wrappers.R | 54 +++++ r_bindings/causalgraphs/src/rust/src/lib.rs | 227 ++++++++++++++++++ .../tests/testthat/{test.R => test_dag.R} | 0 .../causalgraphs/tests/testthat/test_pdag.R | 135 +++++++++++ 5 files changed, 418 insertions(+) rename r_bindings/causalgraphs/tests/testthat/{test.R => test_dag.R} (100%) create mode 100644 r_bindings/causalgraphs/tests/testthat/test_pdag.R diff --git a/r_bindings/causalgraphs/NAMESPACE b/r_bindings/causalgraphs/NAMESPACE index e16e0fa..273388f 100644 --- a/r_bindings/causalgraphs/NAMESPACE +++ b/r_bindings/causalgraphs/NAMESPACE @@ -1,5 +1,7 @@ # Generated by roxygen2: do not edit by hand +S3method("$",PDAG) S3method("$",RDAG) +S3method("[[",PDAG) S3method("[[",RDAG) useDynLib(causalgraphs, .registration = TRUE) diff --git a/r_bindings/causalgraphs/R/extendr-wrappers.R b/r_bindings/causalgraphs/R/extendr-wrappers.R index 64ec308..a59d20a 100644 --- a/r_bindings/causalgraphs/R/extendr-wrappers.R +++ b/r_bindings/causalgraphs/R/extendr-wrappers.R @@ -42,5 +42,59 @@ RDAG$latents <- function() .Call(wrap__RDAG__latents, self) #' @export `[[.RDAG` <- `$.RDAG` +PDAG <- new.env(parent = emptyenv()) + +PDAG$new <- function() .Call(wrap__PDAG__new) + +PDAG$add_node <- function(node, latent) .Call(wrap__PDAG__add_node, self, node, latent) + +PDAG$add_nodes_from <- function(nodes, latent) .Call(wrap__PDAG__add_nodes_from, self, nodes, latent) + +PDAG$add_edge <- function(u, v, weight, directed) .Call(wrap__PDAG__add_edge, self, u, v, weight, directed) + +PDAG$add_edges_from <- function(ebunch, weights, directed) .Call(wrap__PDAG__add_edges_from, self, ebunch, weights, directed) + +PDAG$edges <- function() .Call(wrap__PDAG__edges, self) + +PDAG$nodes <- function() .Call(wrap__PDAG__nodes, self) + +PDAG$node_count <- function() .Call(wrap__PDAG__node_count, self) + +PDAG$edge_count <- function() .Call(wrap__PDAG__edge_count, self) + +PDAG$latents <- function() .Call(wrap__PDAG__latents, self) + +PDAG$directed_edges <- function() .Call(wrap__PDAG__directed_edges, self) + +PDAG$undirected_edges <- function() .Call(wrap__PDAG__undirected_edges, self) + +PDAG$all_neighbors <- function(node) .Call(wrap__PDAG__all_neighbors, self, node) + +PDAG$directed_children <- function(node) .Call(wrap__PDAG__directed_children, self, node) + +PDAG$directed_parents <- function(node) .Call(wrap__PDAG__directed_parents, self, node) + +PDAG$has_directed_edge <- function(u, v) .Call(wrap__PDAG__has_directed_edge, self, u, v) + +PDAG$has_undirected_edge <- function(u, v) .Call(wrap__PDAG__has_undirected_edge, self, u, v) + +PDAG$undirected_neighbors <- function(node) .Call(wrap__PDAG__undirected_neighbors, self, node) + +PDAG$is_adjacent <- function(u, v) .Call(wrap__PDAG__is_adjacent, self, u, v) + +PDAG$copy <- function() .Call(wrap__PDAG__copy, self) + +PDAG$orient_undirected_edge <- function(u, v, inplace) .Call(wrap__PDAG__orient_undirected_edge, self, u, v, inplace) + +PDAG$apply_meeks_rules <- function(apply_r4, inplace) .Call(wrap__PDAG__apply_meeks_rules, self, apply_r4, inplace) + +PDAG$to_dag <- function() .Call(wrap__PDAG__to_dag, self) + +#' @export +`$.PDAG` <- function (self, name) { func <- PDAG[[name]]; environment(func) <- environment(); func } + +#' @export +`[[.PDAG` <- `$.PDAG` + # nolint end diff --git a/r_bindings/causalgraphs/src/rust/src/lib.rs b/r_bindings/causalgraphs/src/rust/src/lib.rs index 9942185..69ff34a 100644 --- a/r_bindings/causalgraphs/src/rust/src/lib.rs +++ b/r_bindings/causalgraphs/src/rust/src/lib.rs @@ -1,5 +1,6 @@ use extendr_api::prelude::*; use rust_core::RustDAG; +use rust_core::RustPDAG; #[extendr] #[derive(Debug, Clone)] @@ -125,10 +126,236 @@ impl RDAG { } } + +#[extendr] +#[derive(Debug, Clone)] +pub struct PDAG { + inner: RustPDAG, +} + + +#[extendr] +impl PDAG { + /// Create a new PDAG + /// @export + fn new() -> Self { + PDAG { inner: RustPDAG::new() } + } + + /// Add a single node + /// @param node Node name + /// @param latent Whether latent (default FALSE) + /// @export + fn add_node(&mut self, node: String, latent: Option) -> extendr_api::Result<()> { + self.inner.add_node(node, latent.unwrap_or(false)) + .map_err(|e| Error::Other(e.to_string())) + } + + /// Add nodes from vector with optional latent mask (NULL means all false) + /// @param nodes character vector + /// @param latent NULL or logical vector + /// @export + fn add_nodes_from(&mut self, nodes: Strings, latent: Nullable) -> extendr_api::Result<()> { + let node_vec: Vec = nodes.iter().map(|s| s.to_string()).collect(); + let latent_opt: Option> = latent.into_option().map(|v| v.iter().map(|x| x.is_true()).collect()); + self.inner.add_nodes_from(node_vec, latent_opt).map_err(|e| Error::Other(e.to_string())) + } + + /// Add single edge (directed or undirected) + /// @param u source + /// @param v target + /// @param weight optional numeric (NULL) + /// @param directed bool (TRUE: directed, FALSE: undirected) + /// @export + fn add_edge(&mut self, u: String, v: String, weight: Nullable, directed: Option) -> extendr_api::Result<()> { + let w = weight.into_option(); + let d = directed.unwrap_or(true); + self.inner.add_edge(u, v, w, d).map_err(|e| Error::Other(e.to_string())) + } + + /// Add multiple edges from an R list of pairs: list(c("A","B"), c("C","D")) + /// @param ebunch list of character vectors length 2 + /// @param weights NULL or numeric vector + /// @param directed bool + /// @export + fn add_edges_from(&mut self, ebunch: List, weights: Nullable, directed: Option) -> extendr_api::Result<()> { + // convert ebunch (List) -> Vec<(String,String)> + let mut edges: Vec<(String,String)> = Vec::with_capacity(ebunch.len()); + for (i, item) in ebunch.values().enumerate() { + // Each item must be a character vector of length 2 + let pair: Strings = item.try_into().map_err(|_| Error::Other(format!("ebunch[{}] must be a character vector of length 2", i)))?; + if pair.len() != 2 { + return Err(Error::Other(format!("ebunch[{}] must have exactly 2 elements", i))); + } + edges.push((pair[0].to_string(), pair[1].to_string())); + } + let weight_opt: Option> = weights.into_option().map(|v| v.iter().map(|d| d.inner()).collect()); + let directed = directed.unwrap_or(true); + self.inner.add_edges_from(Some(edges), weight_opt, directed).map_err(|e| Error::Other(e.to_string())) + } + + /// Return all edges. For PDAG this includes both directed and undirected (both directions placed into graph). + /// Return as list(from = ..., to = ...) same as RDAG$edges() + /// @export + fn edges(&self) -> List { + let edges = self.inner.edges(); + let (from, to): (Vec<_>, Vec<_>) = edges.into_iter().unzip(); + list!(from = from, to = to) + } + + /// Return nodes + /// @export + fn nodes(&self) -> Strings { + self.inner.nodes().iter().map(|s| s.as_str()).collect::() + } + + /// Number of nodes + /// @export + fn node_count(&self) -> i32 { + self.inner.node_map.len() as i32 + } + + /// Number of edges (count unique graph edges) + /// @export + fn edge_count(&self) -> i32 { + self.inner.edges().len() as i32 + } + + /// Latent nodes + /// @export + fn latents(&self) -> Strings { + let mut v: Vec = self.inner.latents.iter().cloned().collect(); + v.sort(); + v.iter().map(|s| s.as_str()).collect::() + } + + /// Directed edges as a list of 2-element character vectors + /// @export + fn directed_edges(&self) -> List { + let mut vec = self.inner.directed_edges.iter().cloned().collect::>(); + vec.sort(); + let mut out = List::new(vec.len()); + for (i, (u, v)) in vec.into_iter().enumerate() { + let pair = vec![u.as_str(), v.as_str()].iter().map(|s| *s).collect::(); + out.set_elt(i, Into::::into(pair)).unwrap(); + } + out + } + + /// Undirected edges reported as stored (u, v) for each undirected pair (original insertion) + /// @export + fn undirected_edges(&self) -> List { + let mut vec = self.inner.undirected_edges.iter().cloned().collect::>(); + vec.sort(); + let mut out = List::new(vec.len()); + for (i, (u, v)) in vec.into_iter().enumerate() { + let pair = vec![u.as_str(), v.as_str()].iter().map(|s| *s).collect::(); + out.set_elt(i, Into::::into(pair)).unwrap(); + } + out + } + + /// All neighbors (directed or undirected) as character vector + /// @export + fn all_neighbors(&self, node: String) -> extendr_api::Result { + let s = self.inner.all_neighbors(&node).map_err(|e| Error::Other(e))?; + let mut v: Vec = s.into_iter().collect(); + v.sort(); + Ok(v.iter().map(|x| x.as_str()).collect::()) + } + + /// Directed children + /// @export + fn directed_children(&self, node: String) -> extendr_api::Result { + let s = self.inner.directed_children(&node).map_err(|e| Error::Other(e))?; + let mut v: Vec = s.into_iter().collect(); + v.sort(); + Ok(v.iter().map(|x| x.as_str()).collect::()) + } + + /// Directed parents + /// @export + fn directed_parents(&self, node: String) -> extendr_api::Result { + let s = self.inner.directed_parents(&node).map_err(|e| Error::Other(e))?; + let mut v: Vec = s.into_iter().collect(); + v.sort(); + Ok(v.iter().map(|x| x.as_str()).collect::()) + } + + /// has_directed_edge + /// @export + fn has_directed_edge(&self, u: String, v: String) -> bool { + self.inner.has_directed_edge(&u, &v) + } + + /// has_undirected_edge + /// @export + fn has_undirected_edge(&self, u: String, v: String) -> bool { + self.inner.has_undirected_edge(&u, &v) + } + + /// undirected_neighbors + /// @export + fn undirected_neighbors(&self, node: String) -> extendr_api::Result { + let s = self.inner.undirected_neighbors(&node).map_err(|e| Error::Other(e))?; + let mut v: Vec = s.into_iter().collect(); + v.sort(); + Ok(v.iter().map(|x| x.as_str()).collect::()) + } + + /// is_adjacent + /// @export + fn is_adjacent(&self, u: String, v: String) -> bool { + self.inner.is_adjacent(&u, &v) + } + + /// copy + /// @export + fn copy(&self) -> PDAG { + PDAG { inner: self.inner.copy() } + } + + /// orient_undirected_edge (returns NULL if inplace = TRUE, otherwise returns new PDAG) + /// @param u + /// @param v + /// @param inplace default TRUE + /// @export + fn orient_undirected_edge(&mut self, u: String, v: String, inplace: Option) -> extendr_api::Result> { + let in_place = inplace.unwrap_or(true); + match self.inner.orient_undirected_edge(&u, &v, in_place) { + Ok(None) => Ok(Nullable::Null), + Ok(Some(pdag)) => Ok(Nullable::NotNull(PDAG { inner: pdag })), + Err(e) => Err(Error::Other(e)), + } + } + + /// apply_meeks_rules (apply_r4 bool, inplace bool) + /// @export + fn apply_meeks_rules(&mut self, apply_r4: Option, inplace: Option) -> extendr_api::Result> { + let apply_r4 = apply_r4.unwrap_or(true); + let inplace = inplace.unwrap_or(false); + match self.inner.apply_meeks_rules(apply_r4, inplace) { + Ok(None) => Ok(Nullable::Null), + Ok(Some(pdag)) => Ok(Nullable::NotNull(PDAG { inner: pdag })), + Err(e) => Err(Error::Other(e)), + } + } + + /// to_dag -> RDAG + /// @export + fn to_dag(&self) -> extendr_api::Result { + let dag = self.inner.to_dag().map_err(|e| Error::Other(e))?; + Ok(RDAG { inner: dag }) + } +} + + + // Macro to generate exports. // This ensures exported functions are registered with R. // See corresponding C code in `entrypoint.c` extendr_module! { mod causalgraphs; impl RDAG; + impl PDAG; } diff --git a/r_bindings/causalgraphs/tests/testthat/test.R b/r_bindings/causalgraphs/tests/testthat/test_dag.R similarity index 100% rename from r_bindings/causalgraphs/tests/testthat/test.R rename to r_bindings/causalgraphs/tests/testthat/test_dag.R diff --git a/r_bindings/causalgraphs/tests/testthat/test_pdag.R b/r_bindings/causalgraphs/tests/testthat/test_pdag.R new file mode 100644 index 0000000..fc24d11 --- /dev/null +++ b/r_bindings/causalgraphs/tests/testthat/test_pdag.R @@ -0,0 +1,135 @@ +library(causalgraphs) +library(testthat) + + +test_that("basic PDAG operations and properties", { + pdag <- PDAG$new() + pdag$add_edges_from(list(c("A", "C"), c("D", "C")), weights = NULL, directed = TRUE) + pdag$add_edges_from(list(c("B", "A"), c("B", "D")), weights = NULL, directed = FALSE) + + expect_setequal(pdag$nodes(), c("A", "B", "C", "D")) + expect_equal(pdag$node_count(), 4L) + expect_equal(pdag$edge_count(), 6L) + + # Check directed edges + dir_edges <- pdag$directed_edges() + expect_length(dir_edges, 2) + expect_setequal(sapply(dir_edges, paste, collapse="->"), c("A->C", "D->C")) + + # Check undirected edges + undir_edges <- pdag$undirected_edges() + expect_length(undir_edges, 2) + # Sorting to ensure consistent comparison + undir_pairs <- sapply(undir_edges, function(x) paste(sort(x), collapse="-")) + expect_setequal(undir_pairs, c("A-B", "B-D")) + + # Check all edges in the representation + all_edges <- pdag$edges() + all_edges_str <- paste0(all_edges$from, "->", all_edges$to) + expect_setequal(all_edges_str, c("A->C", "D->C", "A->B", "B->A", "B->D", "D->B")) +}) + +test_that("PDAG neighbor and parent/child queries work correctly", { + pdag <- PDAG$new() + pdag$add_edges_from(list(c("A", "C"), c("D", "C")), weights = NULL, directed = TRUE) + pdag$add_edges_from(list(c("B", "A"), c("B", "D")), weights = NULL, directed = FALSE) + + expect_setequal(pdag$all_neighbors("A"), c("B", "C")) + expect_setequal(pdag$all_neighbors("B"), c("A", "D")) + expect_setequal(pdag$all_neighbors("C"), c("A", "D")) + expect_setequal(pdag$all_neighbors("D"), c("B", "C")) + + expect_setequal(pdag$directed_children("A"), "C") + expect_length(pdag$directed_children("B"), 0) + expect_setequal(pdag$directed_parents("C"), c("A", "D")) + + expect_setequal(pdag$undirected_neighbors("A"), "B") + expect_setequal(pdag$undirected_neighbors("B"), c("A", "D")) + expect_length(pdag$undirected_neighbors("C"), 0) +}) + + +test_that("PDAG edge existence checks work", { + pdag <- PDAG$new() + pdag$add_edges_from(list(c("A", "C"), c("D", "C")), weights = NULL, directed = TRUE) + pdag$add_edges_from(list(c("B", "A"), c("B", "D")), weights = NULL, directed = FALSE) + + expect_true(pdag$has_directed_edge("A", "C")) + expect_false(pdag$has_directed_edge("C", "A")) + expect_false(pdag$has_directed_edge("A", "B")) + + expect_true(pdag$has_undirected_edge("A", "B")) + expect_true(pdag$has_undirected_edge("B", "A")) + expect_false(pdag$has_undirected_edge("A", "C")) + + expect_true(pdag$is_adjacent("A", "B")) + expect_true(pdag$is_adjacent("A", "C")) + expect_false(pdag$is_adjacent("A", "D")) +}) + + +test_that("PDAG copy and orient_undirected_edge work", { + pdag <- PDAG$new() + pdag$add_edges_from(list(c("A", "C"), c("D", "C")), weights = NULL, directed = TRUE) + pdag$add_edges_from(list(c("B", "A"), c("B", "D")), weights = NULL, directed = FALSE) + + # Test copy + pdag_copy <- pdag$copy() + expect_equal(pdag$nodes(), pdag_copy$nodes()) + expect_equal(pdag$directed_edges(), pdag_copy$directed_edges()) + expect_equal(pdag$undirected_edges(), pdag_copy$undirected_edges()) + + # Test orient_undirected_edge (not in-place) + mod_pdag <- pdag$orient_undirected_edge("B", "A", inplace = FALSE) + expect_false(is.null(mod_pdag)) + expect_setequal(sapply(mod_pdag$directed_edges(), paste, collapse="->"), c("A->C", "D->C", "B->A")) + expect_setequal(sapply(mod_pdag$undirected_edges(), function(x) paste(sort(x), collapse="-")), "B-D") + + # Test orient_undirected_edge (in-place) + pdag$orient_undirected_edge("B", "A", inplace = TRUE) + expect_setequal(sapply(pdag$directed_edges(), paste, collapse="->"), c("A->C", "D->C", "B->A")) + expect_setequal(sapply(pdag$undirected_edges(), function(x) paste(sort(x), collapse="-")), "B-D") + + # Orienting an already directed edge should fail + expect_error(pdag$orient_undirected_edge("B", "A", inplace = TRUE)) +}) + +test_that("PDAG to_dag conversion works", { + pdag <- PDAG$new() + pdag$add_edges_from(list(c("A", "B"), c("C", "B")), weights = NULL, directed = TRUE) + pdag$add_edges_from(list(c("C", "D"), c("D", "A")), weights = NULL, directed = FALSE) + + dag <- pdag$to_dag() + expect_s3_class(dag, "RDAG") + expect_equal(dag$edge_count(), 4L) + + e <- dag$edges() + edges_str <- paste0(e$from, "->", e$to) + expect_true("A->B" %in% edges_str) + expect_true("C->B" %in% edges_str) + # Should not create a v-structure at D + expect_false(all(c("A->D", "C->D") %in% edges_str)) +}) + + +test_that("PDAG apply_meeks_rules works", { + # Test case 1: A -> B - C => A -> B -> C + pdag <- PDAG$new() + pdag$add_edge("A", "B", weight = NULL, directed = TRUE) + pdag$add_edge("B", "C", weight = NULL, directed = FALSE) + cpdag <- pdag$apply_meeks_rules(apply_r4 = TRUE, inplace = FALSE) + + e <- cpdag$edges() + edges_str <- paste0(e$from, "->", e$to) + expect_setequal(edges_str, c("A->B", "B->C")) + + # Test case 2: A -> B, D -> C, B - C => No change (potential v-structure) + pdag2 <- PDAG$new() + pdag2$add_edges_from(list(c("A", "B"), c("D", "C")), weights = NULL, directed = TRUE) + pdag2$add_edge("B", "C", weight = NULL, directed = FALSE) + cpdag2 <- pdag2$apply_meeks_rules(apply_r4 = TRUE, inplace = FALSE) + + e2 <- cpdag2$edges() + edges_str2 <- paste0(e2$from, "->", e2$to) + expect_setequal(edges_str2, c("A->B", "D->C", "B->C", "C->B")) +}) \ No newline at end of file