From cef855cd0d5fd6aa010fb90095d9741938ad428b Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Thu, 11 Sep 2025 23:17:27 +0530 Subject: [PATCH 1/6] * refactor minimal_dseparator * start & end nodes accept a list --- rust_core/src/dag.rs | 81 ++++++++++++++++++++++++------------- rust_core/tests/test_dag.rs | 14 +++---- 2 files changed, 60 insertions(+), 35 deletions(-) diff --git a/rust_core/src/dag.rs b/rust_core/src/dag.rs index bf4d31d..7d1cbec 100644 --- a/rust_core/src/dag.rs +++ b/rust_core/src/dag.rs @@ -379,30 +379,44 @@ impl RustDAG { /// Tian, Paz, Pearl (1998), *Finding Minimal d-Separators*. pub fn minimal_dseparator( &self, - start: &str, - end: &str, + starts: Vec, + ends: Vec, include_latents: bool, ) -> Result>, String> { - // Example: For DAG A→B←C, B→D, trying to separate A and C - // Adjacent nodes can't be separated by any conditioning set - if self.has_edge(start, end) || self.has_edge(end, start) { - return Err("No possible separators because start and end are adjacent".to_string()); + // Validate inputs + if starts.is_empty() || ends.is_empty() { + return Ok(Some(HashSet::new())); } - // Create ancestral graph containing only ancestors of start and end - // Example: For separating A and D in A→B←C, B→D, ancestral graph = {A, B, C, D} - let ancestral_graph = self.get_ancestral_graph(vec![start.to_string(), end.to_string()])?; + // Check for adjacent pairs - if any start-end pair is adjacent, no separator exists + for start in &starts { + for end in &ends { + if self.has_edge(start, end) || self.has_edge(end, start) { + return Err(format!( + "No possible separators because {} and {} are adjacent", + start, end + )); + } + } + } - // Initial separator: all parents of both nodes (theoretical upper bound) - // Example: parents(A)={} ∪ parents(D)={B} → separator = {B} - let mut separator: HashSet = self - .get_parents(start)? 
- .into_iter() - .chain(self.get_parents(end)?.into_iter()) - .collect(); + + // Create ancestral graph containing only ancestors of all starts and ends + let mut all_nodes = starts.clone(); + all_nodes.extend(ends.clone()); + let ancestral_graph = self.get_ancestral_graph(all_nodes)?; + + // Initial separator: all parents of all start and end nodes + let mut separator: HashSet = HashSet::new(); + + for start in &starts { + separator.extend(self.get_parents(start)?); + } + for end in &ends { + separator.extend(self.get_parents(end)?); + } // Replace latent variables with their observable parents - // Example: If B were latent with parent L, replace B with L in separator if !include_latents { let mut changed = true; while changed { @@ -421,21 +435,32 @@ impl RustDAG { } } - separator.remove(start); - separator.remove(end); + // Remove starts and ends from separator (can't separate a node from itself) + for start in &starts { + separator.remove(start); + } + for end in &ends { + separator.remove(end); + } + + // Helper function to check if all start-end pairs are d-separated + let check_all_separated = |sep: &[String]| -> Result { + for start in &starts { + for end in &ends { + if ancestral_graph.is_dconnected(start, end, Some(sep.to_vec()), include_latents)? { + return Ok(false); // Found a connected pair + } + } + } + Ok(true) // All pairs are separated + }; // Sanity check: if our "guaranteed" separator doesn't work, no separator exists - if ancestral_graph.is_dconnected( - start, - end, - Some(separator.iter().cloned().collect()), - include_latents, - )? { + if !check_all_separated(&separator.iter().cloned().collect::>())? 
{ return Ok(None); } // Greedy minimization: remove each node if separation still holds without it - // Example: If separator = {B, C} but {B} alone separates A from D, remove C let mut minimal_separator = separator.clone(); for u in separator { let test_separator: Vec = minimal_separator @@ -444,8 +469,8 @@ impl RustDAG { .filter(|x| x != &u) .collect(); - // If still d-separated WITHOUT this node, we can remove it - if !ancestral_graph.is_dconnected(start, end, Some(test_separator), include_latents)? { + // If all pairs are still d-separated WITHOUT this node, we can remove it + if check_all_separated(&test_separator)? { minimal_separator.remove(&u); } } diff --git a/rust_core/tests/test_dag.rs b/rust_core/tests/test_dag.rs index 01ca853..bb52200 100644 --- a/rust_core/tests/test_dag.rs +++ b/rust_core/tests/test_dag.rs @@ -127,7 +127,7 @@ fn test_minimal_dseparator_simple() { None, ) .unwrap(); - let result = dag.minimal_dseparator("A", "C", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["C".to_string()], false).unwrap(); let expected: HashSet = vec!["B".to_string()].into_iter().collect(); assert_eq!(result, Some(expected)); } @@ -146,7 +146,7 @@ fn test_minimal_dseparator_complex() { None, ) .unwrap(); - let result = dag.minimal_dseparator("A", "D", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["D".to_string()], false).unwrap(); let expected: HashSet = vec!["C".to_string(), "E".to_string()].into_iter().collect(); assert_eq!(result, Some(expected)); } @@ -166,7 +166,7 @@ fn test_minimal_dseparator_latent_case_1() { ) .unwrap(); // No d-separator should exist because B is latent - let result = dag.minimal_dseparator("A", "C", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["C".to_string()], false).unwrap(); assert_eq!(result, None); } @@ -187,7 +187,7 @@ fn test_minimal_dseparator_latent_case_2() { ) .unwrap(); - let result = dag.minimal_dseparator("A", 
"C", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["C".to_string()], false).unwrap(); let expected: HashSet = vec!["D".to_string()].into_iter().collect(); assert_eq!( result, @@ -213,7 +213,7 @@ fn test_minimal_dseparator_latent_case_3() { None, ) .unwrap(); - let result = dag.minimal_dseparator("A", "C", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["C".to_string()], false).unwrap(); assert_eq!(result, None, "Expected no d-separator when D is latent with multiple paths A→B→C and A→D→C (D is unobservable, because of its latent status)"); } @@ -238,7 +238,7 @@ fn test_minimal_dseparator_latent_case_5() { None, ) .unwrap(); - let result = dag.minimal_dseparator("A", "C", false).unwrap(); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["C".to_string()], false).unwrap(); let expected: HashSet = vec!["B".to_string(), "D".to_string()].into_iter().collect(); assert_eq!( result, @@ -253,7 +253,7 @@ fn test_minimal_dseparator_adjacent_error() { dag.add_edges_from(vec![("A".to_string(), "B".to_string())], None) .unwrap(); - let result = dag.minimal_dseparator("A", "B", false); + let result = dag.minimal_dseparator(vec!["A".to_string()], vec!["B".to_string()], false); assert!(result.is_err()); assert!(result.unwrap_err().contains("adjacent")); } From 9357448ed0ae7674465cc50f0b0e65daa72d05af Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Thu, 11 Sep 2025 23:36:09 +0530 Subject: [PATCH 2/6] add tests --- rust_core/tests/test_dag.rs | 95 +++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/rust_core/tests/test_dag.rs b/rust_core/tests/test_dag.rs index bb52200..9715abf 100644 --- a/rust_core/tests/test_dag.rs +++ b/rust_core/tests/test_dag.rs @@ -257,3 +257,98 @@ fn test_minimal_dseparator_adjacent_error() { assert!(result.is_err()); assert!(result.unwrap_err().contains("adjacent")); } + +#[test] +fn test_multiple_chains_separation() { + // Graph: 
A1→B1→C1, A2→B2→C2, A3→B3→C3 (three independent chains) + let mut dag = RustDAG::new(); + dag.add_edges_from( + vec![ + ("A1".to_string(), "B1".to_string()), + ("B1".to_string(), "C1".to_string()), + ("A2".to_string(), "B2".to_string()), + ("B2".to_string(), "C2".to_string()), + ("A3".to_string(), "B3".to_string()), + ("B3".to_string(), "C3".to_string()), + ], + None, + ).unwrap(); + + // Separate {A1, A2, A3} from {C1, C2, C3} + let result = dag.minimal_dseparator( + vec!["A1".to_string(), "A2".to_string(), "A3".to_string()], + vec!["C1".to_string(), "C2".to_string(), "C3".to_string()], + false + ).unwrap(); + + let expected: HashSet = vec![ + "B1".to_string(), "B2".to_string(), "B3".to_string() + ].into_iter().collect(); + assert_eq!(result, Some(expected)); +} + +#[test] +fn test_fork_convergence_pattern() { + // Graph: A→B, A→C, B→D, C→D (fork from A, convergence at D) + let mut dag = RustDAG::new(); + dag.add_edges_from( + vec![ + ("A".to_string(), "B".to_string()), + ("A".to_string(), "C".to_string()), + ("B".to_string(), "D".to_string()), + ("C".to_string(), "D".to_string()), + ], + None, + ).unwrap(); + + // Attempt to separate {B, C} from {D} - B and C are both direct parents of D + let result = dag.minimal_dseparator( + vec!["B".to_string(), "C".to_string()], + vec!["D".to_string()], + false + ); + + // No separator should exist because B→D and C→D are direct edges + assert!(result.is_err()); + assert!(result.unwrap_err().contains("No possible separators because B and D are adjacent")); +} + +#[test] +fn test_diamond_pattern_multiple() { + // Graph: Multiple diamond patterns + // A1→B1, A1→C1, B1→D1, C1→D1 + // A2→B2, A2→C2, B2→D2, C2→D2 + let mut dag = RustDAG::new(); + dag.add_edges_from( + vec![ + ("A1".to_string(), "B1".to_string()), + ("A1".to_string(), "C1".to_string()), + ("B1".to_string(), "D1".to_string()), + ("C1".to_string(), "D1".to_string()), + ("A2".to_string(), "B2".to_string()), + ("A2".to_string(), "C2".to_string()), + ("B2".to_string(), "D2".to_string()), + 
("C2".to_string(), "D2".to_string()), + ], + None, + ).unwrap(); + + // Separate {A1, A2} from {D1, D2} + let result = dag.minimal_dseparator( + vec!["A1".to_string(), "A2".to_string()], + vec!["D1".to_string(), "D2".to_string()], + false + ).unwrap(); + + assert!(result.is_some()); + let separator: HashSet = result.unwrap(); + + // Should include intermediate nodes from both diamonds + assert!(!separator.is_empty()); + // Should contain some combination of B1, C1, B2, C2 + let intermediate_nodes: HashSet = vec![ + "B1".to_string(), "C1".to_string(), "B2".to_string(), "C2".to_string() + ].into_iter().collect(); + + assert_eq!(separator, intermediate_nodes); +} \ No newline at end of file From 567e03450ef1885ceafa1e84bb7078d1273a4dfe Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 12 Sep 2025 00:14:32 +0530 Subject: [PATCH 3/6] R bindings refactors --- r_bindings/causalgraphs/R/extendr-wrappers.R | 2 +- r_bindings/causalgraphs/src/rust/Cargo.toml | 4 ++-- r_bindings/causalgraphs/src/rust/src/lib.rs | 6 ++++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/r_bindings/causalgraphs/R/extendr-wrappers.R b/r_bindings/causalgraphs/R/extendr-wrappers.R index 8431564..f0722ca 100644 --- a/r_bindings/causalgraphs/R/extendr-wrappers.R +++ b/r_bindings/causalgraphs/R/extendr-wrappers.R @@ -46,7 +46,7 @@ RDAG$are_neighbors <- function(start, end) .Call(wrap__RDAG__are_neighbors, self RDAG$get_ancestral_graph <- function(nodes) .Call(wrap__RDAG__get_ancestral_graph, self, nodes) -RDAG$minimal_dseparator <- function(start, end, include_latents) .Call(wrap__RDAG__minimal_dseparator, self, start, end, include_latents) +RDAG$minimal_dseparator <- function(starts, ends, include_latents) .Call(wrap__RDAG__minimal_dseparator, self, starts, ends, include_latents) #' @export `$.RDAG` <- function (self, name) { func <- RDAG[[name]]; environment(func) <- environment(); func } diff --git a/r_bindings/causalgraphs/src/rust/Cargo.toml 
b/r_bindings/causalgraphs/src/rust/Cargo.toml index 14ec67e..860d497 100644 --- a/r_bindings/causalgraphs/src/rust/Cargo.toml +++ b/r_bindings/causalgraphs/src/rust/Cargo.toml @@ -11,9 +11,9 @@ name = 'rcausalgraphs' [dependencies] -rust_core = { git = "https://github.com/pgmpy/causalgraphs.git", branch = "main", package = "rust_core" } +# rust_core = { git = "https://github.com/pgmpy/causalgraphs.git", branch = "main", package = "rust_core" } # For local development, comment out the Git line above and uncomment this: -# rust_core = { path = "../../../../rust_core" } +rust_core = { path = "../../../../rust_core" } extendr-api = '*' diff --git a/r_bindings/causalgraphs/src/rust/src/lib.rs b/r_bindings/causalgraphs/src/rust/src/lib.rs index bc092d7..f36fb57 100644 --- a/r_bindings/causalgraphs/src/rust/src/lib.rs +++ b/r_bindings/causalgraphs/src/rust/src/lib.rs @@ -213,8 +213,10 @@ impl RDAG { /// @param end Ending node /// @param include_latents Whether to include latents (default: FALSE) /// @export - fn minimal_dseparator(&self, start: String, end: String, include_latents: Option) -> extendr_api::Result> { - let result = self.inner.minimal_dseparator(&start, &end, include_latents.unwrap_or(false)) + fn minimal_dseparator(&self, starts: Strings, ends: Strings, include_latents: Option) -> extendr_api::Result> { + let starts_vec: Vec = starts.iter().map(|s| s.to_string()).collect(); + let ends_vec: Vec = ends.iter().map(|s| s.to_string()).collect(); + let result = self.inner.minimal_dseparator(starts_vec.clone(), ends_vec.clone(), include_latents.unwrap_or(false)) .map_err(|e| Error::Other(e.to_string()))?; match result { Some(set) => { From 2e0813617be1f718841c8f171638e33b6b6c3089 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 12 Sep 2025 00:19:14 +0530 Subject: [PATCH 4/6] python bindings refactors --- python_bindings/src/lib.rs | 8 ++++---- python_bindings/tests/test_dag.py | 24 ++++++++++++------------ 2 files changed, 16 insertions(+), 16 deletions(-) 
diff --git a/python_bindings/src/lib.rs b/python_bindings/src/lib.rs index 2b39a67..9abec2b 100644 --- a/python_bindings/src/lib.rs +++ b/python_bindings/src/lib.rs @@ -306,15 +306,15 @@ impl PyRustDAG { .map_err(PyValueError::new_err) } - #[pyo3(signature = (start, end, include_latents=false))] + #[pyo3(signature = (starts, ends, include_latents=false))] pub fn minimal_dseparator( &self, - start: String, - end: String, + starts: Vec, + ends: Vec, include_latents: bool, ) -> PyResult>> { self.inner - .minimal_dseparator(&start, &end, include_latents) + .minimal_dseparator(starts.clone(), ends.clone(), include_latents) .map_err(PyValueError::new_err) } diff --git a/python_bindings/tests/test_dag.py b/python_bindings/tests/test_dag.py index 6312abd..9ed0ab9 100644 --- a/python_bindings/tests/test_dag.py +++ b/python_bindings/tests/test_dag.py @@ -21,45 +21,45 @@ def test_parents_children(self, dag): assert dag.get_children("X") == ["Y"] def test_minimal_dseparator(self): - # Test case: A → B ← C + # Test case: A → B → C dag1 = DAG() dag1.add_edges_from([("A", "B"), ("B", "C")]) - assert dag1.minimal_dseparator("A", "C") == {"B"} + assert dag1.minimal_dseparator(["A"], ["C"]) == {"B"} - # Test case: A → B ← C, B → D, A → E, E → D + # Test case: A → B → C, C → D, A → E, E → D dag2 = DAG() dag2.add_edges_from([("A", "B"), ("B", "C"), ("C", "D"), ("A", "E"), ("E", "D")]) - assert dag2.minimal_dseparator("A", "D") == {"C", "E"} + assert dag2.minimal_dseparator(["A"], ["D"]) == {"C", "E"} # Test case: B → A, B → C, A → D, D → C, A → E, C → E dag3 = DAG() dag3.add_edges_from([("B", "A"), ("B", "C"), ("A", "D"), ("D", "C"), ("A", "E"), ("C", "E")]) - assert dag3.minimal_dseparator("A", "C") == {"B", "D"} + assert dag3.minimal_dseparator(["A"], ["C"]) == {"B", "D"} # Test with latents dag_lat1 = DAG() dag_lat1.add_nodes_from(["A", "B", "C"], latent=[False, True, False]) dag_lat1.add_edges_from([("A", "B"), ("B", "C")]) - assert dag_lat1.minimal_dseparator("A", "C") is None - 
# assert dag_lat1.minimal_dseparator("A", "C", include_latents=True) == {"B"} + assert dag_lat1.minimal_dseparator(["A"], ["C"]) is None + # assert dag_lat1.minimal_dseparator(["A"], ["C"], include_latents=True) == {"B"} dag_lat2 = DAG() dag_lat2.add_nodes_from(["A", "B", "C", "D"], latent=[False, True, False, False]) dag_lat2.add_edges_from([("A", "D"), ("D", "B"), ("B", "C")]) - assert dag_lat2.minimal_dseparator("A", "C") == {"D"} + assert dag_lat2.minimal_dseparator(["A"], ["C"]) == {"D"} dag_lat3 = DAG() dag_lat3.add_nodes_from(["A", "B", "C", "D"], latent=[False, True, False, False]) dag_lat3.add_edges_from([("A", "B"), ("B", "D"), ("D", "C")]) - assert dag_lat3.minimal_dseparator("A", "C") == {"D"} + assert dag_lat3.minimal_dseparator(["A"], ["C"]) == {"D"} dag_lat4 = DAG() dag_lat4.add_nodes_from(["A", "B", "C", "D"], latent=[False, False, False, True]) dag_lat4.add_edges_from([("A", "B"), ("B", "C"), ("A", "D"), ("D", "C")]) - assert dag_lat4.minimal_dseparator("A", "C") is None + assert dag_lat4.minimal_dseparator(["A"], ["C"]) is None # Test adjacent nodes (should raise error) dag5 = DAG() dag5.add_edges_from([("A", "B")]) - with pytest.raises(ValueError, match="No possible separators because start and end are adjacent"): - dag5.minimal_dseparator("A", "B") + with pytest.raises(ValueError, match="No possible separators because A and B are adjacent"): + dag5.minimal_dseparator(["A"], ["B"]) \ No newline at end of file From 2d57b60f3ab3fbf811937ded5805340ab66bcec1 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Fri, 12 Sep 2025 00:24:56 +0530 Subject: [PATCH 5/6] wasm test refactors --- wasm_bindings/js/tests/test-dag.js | 8 ++++---- wasm_bindings/src/lib.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/wasm_bindings/js/tests/test-dag.js b/wasm_bindings/js/tests/test-dag.js index dffee2e..fdb0d7d 100644 --- a/wasm_bindings/js/tests/test-dag.js +++ b/wasm_bindings/js/tests/test-dag.js @@ -43,11 +43,11 @@ describe("DAG wasm 
(CJS)", () => { expect(areNeighbors).toBe(false); }); - it("should compute minimal d-separator (simple)", () => { + it("should compute minimal d-separator (simple)", () => { const dag = new cg.DAG(); dag.addEdge("A", "B"); dag.addEdge("B", "C"); - const sep = dag.minimalDseparator("A", "C"); + const sep = dag.minimalDseparator(["A"], ["C"]); expect(sep.sort()).toEqual(["B"]); }); @@ -58,7 +58,7 @@ describe("DAG wasm (CJS)", () => { dag.addEdge("C", "D"); dag.addEdge("A", "E"); dag.addEdge("E", "D"); - const sep = dag.minimalDseparator("A", "D"); + const sep = dag.minimalDseparator(["A"], ["D"]); expect(sep.sort()).toEqual(["C", "E"]); }); @@ -69,7 +69,7 @@ describe("DAG wasm (CJS)", () => { dag.addNode("C", false); dag.addEdge("A", "B"); dag.addEdge("B", "C"); - const sep = dag.minimalDseparator("A", "C"); + const sep = dag.minimalDseparator(["A"], ["C"]); expect(sep).toBeNull(); }); diff --git a/wasm_bindings/src/lib.rs b/wasm_bindings/src/lib.rs index e1155f2..42b2b67 100644 --- a/wasm_bindings/src/lib.rs +++ b/wasm_bindings/src/lib.rs @@ -98,8 +98,8 @@ impl DAG { // In RustDAG impl #[wasm_bindgen(js_name = minimalDseparator, catch)] - pub fn minimal_dseparator(&self, start: String, end: String, include_latents: Option) -> Result { - let result = self.inner.minimal_dseparator(&start, &end, include_latents.unwrap_or(false)) + pub fn minimal_dseparator(&self, starts: Vec, ends: Vec, include_latents: Option) -> Result { + let result = self.inner.minimal_dseparator(starts, ends, include_latents.unwrap_or(false)) .map_err(|e| JsValue::from_str(&e))?; match result { From 5b7f0b7a1995b798c765c81e686440c4e0f17f2e Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Sat, 27 Sep 2025 14:29:09 +0530 Subject: [PATCH 6/6] trying to fix windows build failure --- .github/workflows/ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 701d641..06e5300 100644 --- a/.github/workflows/ci.yml +++ 
b/.github/workflows/ci.yml @@ -75,6 +75,10 @@ jobs: run: | sudo apt-get install -y libcurl4-openssl-dev libssl-dev Rscript -e 'install.packages(c("devtools", "rextendr"), repos="https://cloud.r-project.org")' + + - name: Add Rust Windows target + if: runner.os == 'Windows' + run: rustup target add x86_64-pc-windows-gnu # Build all bindings (Rust, Python, WASM, R) - name: Build all components