diff --git a/python_bindings/.gitignore b/python_bindings/.gitignore index d0d66fc..7f13ffe 100644 --- a/python_bindings/.gitignore +++ b/python_bindings/.gitignore @@ -70,4 +70,4 @@ docs/_build/ # Pyenv .python-version -*pvt_tests* \ No newline at end of file +*pvt_tests* diff --git a/r_bindings/causalgraphs/NAMESPACE b/r_bindings/causalgraphs/NAMESPACE index e16e0fa..8b85938 100644 --- a/r_bindings/causalgraphs/NAMESPACE +++ b/r_bindings/causalgraphs/NAMESPACE @@ -1,5 +1,9 @@ # Generated by roxygen2: do not edit by hand S3method("$",RDAG) +S3method("$",RIndependenceAssertion) +S3method("$",RIndependencies) S3method("[[",RDAG) +S3method("[[",RIndependenceAssertion) +S3method("[[",RIndependencies) useDynLib(causalgraphs, .registration = TRUE) diff --git a/r_bindings/causalgraphs/R/extendr-wrappers.R b/r_bindings/causalgraphs/R/extendr-wrappers.R index 64ec308..37bd8cb 100644 --- a/r_bindings/causalgraphs/R/extendr-wrappers.R +++ b/r_bindings/causalgraphs/R/extendr-wrappers.R @@ -36,11 +36,75 @@ RDAG$edge_count <- function() .Call(wrap__RDAG__edge_count, self) RDAG$latents <- function() .Call(wrap__RDAG__latents, self) +RDAG$add_edges_from <- function(ebunch, weights) .Call(wrap__RDAG__add_edges_from, self, ebunch, weights) + +RDAG$active_trail_nodes <- function(variables, observed, include_latents) .Call(wrap__RDAG__active_trail_nodes, self, variables, observed, include_latents) + +RDAG$is_dconnected <- function(start, end, observed, include_latents) .Call(wrap__RDAG__is_dconnected, self, start, end, observed, include_latents) + +RDAG$are_neighbors <- function(start, end) .Call(wrap__RDAG__are_neighbors, self, start, end) + +RDAG$get_ancestral_graph <- function(nodes) .Call(wrap__RDAG__get_ancestral_graph, self, nodes) + +RDAG$minimal_dseparator <- function(start, end, include_latents) .Call(wrap__RDAG__minimal_dseparator, self, start, end, include_latents) + #' @export `$.RDAG` <- function (self, name) { func <- RDAG[[name]]; environment(func) <- environment(); func } #' @export `[[.RDAG` <- `$.RDAG` +RIndependenceAssertion <- new.env(parent = emptyenv()) + +RIndependenceAssertion$new <- function(event1, event2, event3) .Call(wrap__RIndependenceAssertion__new, event1, event2, event3) + +RIndependenceAssertion$event1 <- function() .Call(wrap__RIndependenceAssertion__event1, self) + +RIndependenceAssertion$event2 <- function() .Call(wrap__RIndependenceAssertion__event2, self) + +RIndependenceAssertion$event3 <- function() .Call(wrap__RIndependenceAssertion__event3, self) + +RIndependenceAssertion$all_vars <- function() .Call(wrap__RIndependenceAssertion__all_vars, self) + +RIndependenceAssertion$is_unconditional <- function() .Call(wrap__RIndependenceAssertion__is_unconditional, self) + +RIndependenceAssertion$to_latex <- function() .Call(wrap__RIndependenceAssertion__to_latex, self) + +RIndependenceAssertion$to_string <- function() .Call(wrap__RIndependenceAssertion__to_string, self) + +#' @export +`$.RIndependenceAssertion` <- function (self, name) { func <- RIndependenceAssertion[[name]]; environment(func) <- environment(); func } + +#' @export +`[[.RIndependenceAssertion` <- `$.RIndependenceAssertion` + +RIndependencies <- new.env(parent = emptyenv()) + +RIndependencies$new <- function() .Call(wrap__RIndependencies__new) + +RIndependencies$add_assertion <- function(assertion) invisible(.Call(wrap__RIndependencies__add_assertion, self, assertion)) + +RIndependencies$add_assertions_from_tuples <- function(tuples) .Call(wrap__RIndependencies__add_assertions_from_tuples, self, tuples) + +RIndependencies$get_assertions <- function() .Call(wrap__RIndependencies__get_assertions, self) + +RIndependencies$get_all_variables <- function() .Call(wrap__RIndependencies__get_all_variables, self) + +RIndependencies$contains <- function(assertion) .Call(wrap__RIndependencies__contains, self, assertion) + +RIndependencies$closure <- function() .Call(wrap__RIndependencies__closure, self) + +RIndependencies$reduce <- function(inplace) .Call(wrap__RIndependencies__reduce, self, inplace) + +RIndependencies$entails <- function(other) .Call(wrap__RIndependencies__entails, self, other) + +RIndependencies$is_equivalent <- function(other) .Call(wrap__RIndependencies__is_equivalent, self, other) + +#' @export +`$.RIndependencies` <- function (self, name) { func <- RIndependencies[[name]]; environment(func) <- environment(); func } + +#' @export +`[[.RIndependencies` <- `$.RIndependencies` + # nolint end diff --git a/r_bindings/causalgraphs/src/rust/src/lib.rs b/r_bindings/causalgraphs/src/rust/src/lib.rs index 9942185..e345661 100644 --- a/r_bindings/causalgraphs/src/rust/src/lib.rs +++ b/r_bindings/causalgraphs/src/rust/src/lib.rs @@ -1,5 +1,17 @@ use extendr_api::prelude::*; -use rust_core::RustDAG; +use rust_core::{RustDAG, IndependenceAssertion, Independencies}; +use std::collections::HashSet; +use std::panic; + + +#[extendr] +fn on_load() { + panic::set_hook(Box::new(|info| { + eprintln!("Panic: {:?}", info); + })); +} + + #[extendr] #[derive(Debug, Clone)] @@ -22,9 +34,7 @@ impl RDAG { /// @param latent Whether the node is latent (default: FALSE) /// @export fn add_node(&mut self, node: String, latent: Option) -> extendr_api::Result<()> { - self.inner - .add_node(node, latent.unwrap_or(false)) - .map_err(Error::from) + self.inner.add_node(node, latent.unwrap_or(false)).map_err(|e| Error::Other(e.to_string())) } /// Add multiple nodes to the DAG @@ -37,41 +47,32 @@ impl RDAG { latent: Nullable, ) -> extendr_api::Result<()> { let node_vec: Vec = nodes.iter().map(|s| s.to_string()).collect(); - let latent_opt: Option> = latent - .into_option() - .map(|v| v.iter().map(|x| x.is_true()).collect()); - - self.inner - .add_nodes_from(node_vec, latent_opt) - .map_err(|e| Error::Other(e)) + let latent_opt: Option> = latent.into_option().map(|v| v.iter().map(|x| x.is_true()).collect()); + self.inner.add_nodes_from(node_vec, latent_opt).map_err(|e| Error::Other(e.to_string())) } /// Add an edge between two nodes /// @param u Source node - /// @param v Target node - /// @param weight Optional edge weight + /// @param v Target node + /// @param weight Optional edge weight (default: NULL) /// @export fn add_edge(&mut self, u: String, v: String, weight: Nullable) -> extendr_api::Result<()> { let w = weight.into_option(); - self.inner.add_edge(u, v, w).map_err(|e| Error::Other(e)) + self.inner.add_edge(u, v, w).map_err(|e| Error::Other(e.to_string())) } /// Get parents of a node /// @param node The node name /// @export fn get_parents(&self, node: String) -> extendr_api::Result { - let parents = self.inner.get_parents(&node).map_err(|e| Error::Other(e))?; + let parents = self.inner.get_parents(&node).map_err(|e| Error::Other(e.to_string()))?; Ok(parents.iter().map(|s| s.as_str()).collect::()) } - /// Get children of a node /// @param node The node name /// @export fn get_children(&self, node: String) -> extendr_api::Result { - let children = self - .inner - .get_children(&node) - .map_err(|e| Error::Other(e))?; + let children = self.inner.get_children(&node).map_err(|e| Error::Other(e.to_string()))?; Ok(children.iter().map(|s| s.as_str()).collect::()) } @@ -80,10 +81,7 @@ impl RDAG { /// @export fn get_ancestors_of(&self, nodes: Strings) -> extendr_api::Result { let node_vec: Vec = nodes.iter().map(|s| s.to_string()).collect(); - let ancestors = self - .inner - .get_ancestors_of(node_vec) - .map_err(|e| Error::Other(e))?; + let ancestors = self.inner.get_ancestors_of(node_vec).map_err(|e| Error::Other(e.to_string()))?; Ok(ancestors.iter().map(|s| s.as_str()).collect::()) } @@ -93,7 +91,6 @@ impl RDAG { let nodes = self.inner.nodes(); nodes.iter().map(|s| s.as_str()).collect::() } - /// Get all edges in the DAG /// @export fn edges(&self) -> List { @@ -123,12 +120,306 @@ impl RDAG { .map(|s| s.as_str()) .collect::() } + + /// Add multiple edges to the DAG + /// @param ebunch List of (u, v) pairs (each pair as a character vector of length 2) + /// @param weights Optional vector of weights (must match ebunch length) + /// @export + fn add_edges_from(&mut self, ebunch: List, weights: Nullable) -> extendr_api::Result<()> { + let mut edge_vec: Vec<(String, String)> = Vec::with_capacity(ebunch.len()); + let weight_opt: Option> = weights.into_option().map(|v| v.iter().map(|x| x.inner()).collect()); + + if let Some(ref w) = weight_opt { + if w.len() != ebunch.len() { + return Err(Error::Other("Weights length must match ebunch".to_string())); + } + } + + for (i, pair) in ebunch.values().enumerate() { + let pair_vec: Strings = pair.try_into() + .map_err(|_| Error::Other(format!("tuples[{}] must be a list", i)))?; // Changed error message + if pair_vec.len() != 2 { + return Err(Error::Other(format!("ebunch[{}] must have exactly 2 elements", i))); // Removed "(u, v)" part + } + edge_vec.push((pair_vec[0].to_string(), pair_vec[1].to_string())); + } + + self.inner.add_edges_from(edge_vec, weight_opt).map_err(|e| Error::Other(e.to_string())) + } + + /// Get active trail nodes + /// @param variables Vector of starting variables + /// @param observed Optional vector of observed nodes + /// @param include_latents Whether to include latents (default: FALSE) + /// @export + fn active_trail_nodes(&self, variables: Strings, observed: Nullable, include_latents: Option) -> extendr_api::Result { + let var_vec: Vec = variables.iter().map(|s| s.to_string()).collect(); + if var_vec.is_empty() { + return Err(Error::Other("variables cannot be empty".to_string())); + } + let obs_opt: Option> = observed.into_option().map(|v| v.iter().map(|s| s.to_string()).collect()); + + let result = self.inner.active_trail_nodes(var_vec, obs_opt, include_latents.unwrap_or(false)) + .map_err(|e| Error::Other(e.to_string()))?; + + let result_clone = result.clone(); + + let r_list = List::from_names_and_values( + result.keys().map(|k| k.as_str()), + result_clone.into_values().map(|set| { + let vec: Vec = set.into_iter().collect(); + let strings: Strings = vec.iter().map(|s| s.as_str()).collect(); + Into::::into(strings) + }) + )?; + Ok(r_list) + } + + + /// Check if two nodes are d-connected + /// @param start Starting node + /// @param end Ending node + /// @param observed Optional vector of observed nodes + /// @param include_latents Whether to include latents (default: FALSE) + /// @export + fn is_dconnected(&self, start: String, end: String, observed: Nullable, include_latents: Option) -> extendr_api::Result { + let obs_opt: Option> = observed.into_option().map(|v| v.iter().map(|s| s.to_string()).collect()); + self.inner.is_dconnected(&start, &end, obs_opt, include_latents.unwrap_or(false)) + .map_err(|e| Error::Other(e.to_string())) + } + + + /// Check if two nodes are neighbors + /// @param start First node + /// @param end Second node + /// @export + fn are_neighbors(&self, start: String, end: String) -> extendr_api::Result { + self.inner.are_neighbors(&start, &end).map_err(|e| Error::Other(e.to_string())) + } + + /// Get ancestral graph for given nodes + /// @param nodes Vector of nodes + /// @export + fn get_ancestral_graph(&self, nodes: Strings) -> extendr_api::Result { + let node_vec: Vec = nodes.iter().map(|s| s.to_string()).collect(); + self.inner.get_ancestral_graph(node_vec) + .map(|dag| RDAG { inner: dag }) + .map_err(|e| Error::Other(e.to_string())) + } + + /// Get minimal d-separator between two nodes + /// @param start Starting node + /// @param end Ending node + /// @param include_latents Whether to include latents (default: FALSE) + /// @export + fn minimal_dseparator(&self, start: String, end: String, include_latents: Option) -> extendr_api::Result> { + let result = self.inner.minimal_dseparator(&start, &end, include_latents.unwrap_or(false)) + .map_err(|e| Error::Other(e.to_string()))?; + match result { + Some(set) => { + let vec: Vec = set.into_iter().collect(); + Ok(Nullable::NotNull(vec.iter().map(|s| s.as_str()).collect::())) + } + None => Ok(Nullable::Null), + } + } } -// Macro to generate exports. -// This ensures exported functions are registered with R. -// See corresponding C code in `entrypoint.c` +#[extendr] +#[derive(Debug, Clone)] +pub struct RIndependenceAssertion { + inner: IndependenceAssertion, +} + +#[extendr] +impl RIndependenceAssertion { + /// Create a new IndependenceAssertion + /// @param event1 Vector of event1 variables + /// @param event2 Vector of event2 variables + /// @param event3 Optional vector of event3 variables + /// @export + fn new(event1: Strings, event2: Strings, event3: Nullable) -> extendr_api::Result { + let e1: HashSet = event1.iter().map(|s| s.to_string()).collect(); + let e2: HashSet = event2.iter().map(|s| s.to_string()).collect(); + let e3_opt: Option> = event3.into_option().map(|v| v.iter().map(|s| s.to_string()).collect()); + let inner = IndependenceAssertion::new(e1, e2, e3_opt) + .map_err(|e| Error::Other(e.to_string()))?; + Ok(RIndependenceAssertion { inner }) + } + + /// Get event1 variables + /// @export + fn event1(&self) -> Strings { + let mut result: Vec = self.inner.event1.iter().cloned().collect(); + result.sort(); + result.iter().map(|s| s.as_str()).collect::() + } + + /// Get event2 variables + /// @export + fn event2(&self) -> Strings { + let mut result: Vec = self.inner.event2.iter().cloned().collect(); + result.sort(); + result.iter().map(|s| s.as_str()).collect::() + } + + /// Get event3 variables + /// @export + fn event3(&self) -> Strings { + let mut result: Vec = self.inner.event3.iter().cloned().collect(); + result.sort(); + result.iter().map(|s| s.as_str()).collect::() + } + + /// Get all variables + /// @export + fn all_vars(&self) -> Strings { + let mut result: Vec = self.inner.all_vars.iter().cloned().collect(); + result.sort(); + result.iter().map(|s| s.as_str()).collect::() + } + + /// Check if unconditional + /// @export + fn is_unconditional(&self) -> bool { + self.inner.is_unconditional() + } + + /// Get LaTeX representation + /// @export + fn to_latex(&self) -> String { + self.inner.to_latex() + } + + /// Get string representation + /// @export + fn to_string(&self) -> String { + format!("{}", self.inner) + } +} + +#[extendr] +#[derive(Debug, Clone)] +pub struct RIndependencies { + inner: Independencies, +} + +#[extendr] +impl RIndependencies { + /// Create a new Independencies + /// @export + fn new() -> Self { + RIndependencies { inner: Independencies::new() } + } + + /// Add a single assertion + /// @param assertion An RIndependenceAssertion object + /// @export + fn add_assertion(&mut self, assertion: &RIndependenceAssertion) { + self.inner.add_assertion(assertion.inner.clone()); + } + + /// Add multiple assertions from R tuples + /// @param tuples A list of 2- or 3-tuples `(event1, event2, event3)` + /// @export + fn add_assertions_from_tuples(&mut self, tuples: List) -> extendr_api::Result<()> { + let mut rust_tuples: Vec<(Vec, Vec, Option>)> = Vec::with_capacity(tuples.len()); + + for (i, pair) in tuples.values().enumerate() { + if pair.is_null() { + continue; // Skip NULL items if any + } + let inner = pair.as_list().ok_or_else(|| Error::Other(format!("tuples[{}] must be a list", i)))?; + if inner.len() < 2 || inner.len() > 3 { + return Err(Error::Other(format!("tuples[{}] must have 2 or 3 elements", i))); + } + + let e1: Strings = inner.elt(0)?.try_into().map_err(|_| Error::Other(format!("tuples[{}][0] must be character vector", i)))?; + let e1_vec = e1.iter().map(|s| s.to_string()).collect::>(); + let e2: Strings = inner.elt(1)?.try_into().map_err(|_| Error::Other(format!("tuples[{}][1] must be character vector", i)))?; + let e2_vec = e2.iter().map(|s| s.to_string()).collect::>(); + + let e3_opt = if inner.len() == 3 { + let e3_robj = inner.elt(2)?; + if e3_robj.is_null() { + None + } else { + let e3: Strings = e3_robj.try_into().map_err(|_| Error::Other(format!("tuples[{}][2] must be character vector", i)))?; + Some(e3.iter().map(|s| s.to_string()).collect::>()) + } + } else { + None + }; + rust_tuples.push((e1_vec, e2_vec, e3_opt)); + } + + self.inner.add_assertions_from_tuples(rust_tuples).map_err(|e| Error::Other(e.to_string())) + } + + /// Get all assertions + /// @export + fn get_assertions(&self) -> List { + let assertions = self.inner.get_assertions(); + let mut r_list = List::new(assertions.len()); + for (i, a) in assertions.iter().enumerate() { + let r_assertion = RIndependenceAssertion { inner: a.clone() }; + r_list.set_elt(i, r_assertion.into()).unwrap(); + } + r_list + } + + /// Get all variables + /// @export + fn get_all_variables(&self) -> Strings { + let mut result: Vec = self.inner.get_all_variables().into_iter().collect(); + result.sort(); + result.iter().map(|s| s.as_str()).collect::() + } + + /// Check if contains assertion + /// @param assertion An RIndependenceAssertion object + /// @export + fn contains(&self, assertion: &RIndependenceAssertion) -> bool { + self.inner.contains(&assertion.inner) + } + + /// Compute closure + /// @export + fn closure(&self) -> RIndependencies { + RIndependencies { inner: self.inner.closure() } + } + + /// Reduce independencies + /// @param inplace Whether to modify in place (default: FALSE) + /// @export + fn reduce(&mut self, inplace: Option) -> Nullable { + if inplace.unwrap_or(false) { + self.inner.reduce_inplace(); + Nullable::Null + } else { + Nullable::NotNull(RIndependencies { inner: self.inner.reduce() }) + } + } + + /// Check if entails another set + /// @param other Another RIndependencies object + /// @export + fn entails(&self, other: &RIndependencies) -> bool { + self.inner.entails(&other.inner) + } + + /// Check if equivalent to another set + /// @param other Another RIndependencies object + /// @export + fn is_equivalent(&self, other: &RIndependencies) -> bool { + self.inner.is_equivalent(&other.inner) + } +} + + extendr_module! { mod causalgraphs; impl RDAG; + impl RIndependenceAssertion; + impl RIndependencies; } diff --git a/r_bindings/causalgraphs/tests/testthat/test.R b/r_bindings/causalgraphs/tests/testthat/test.R index b2fe214..e8ba455 100644 --- a/r_bindings/causalgraphs/tests/testthat/test.R +++ b/r_bindings/causalgraphs/tests/testthat/test.R @@ -1,5 +1,5 @@ -library(causalgraphs) library(testthat) +library(causalgraphs) test_that("basic DAG operations", { dag <- RDAG$new() @@ -105,3 +105,89 @@ test_that("nodes(), edges(), node_count(), edge_count(), latents() remain consis e <- dag$edges() expect_setequal(paste0(e$from, "->", e$to), c("O1->O2","L1->O2")) }) + + +test_that("add_edges_from adds multiple edges correctly", { + dag <- RDAG$new() + dag$add_nodes_from(c("A", "B", "C", "D"), NULL) + ebunch <- list(c("A", "B"), c("C", "D")) + weights <- c(1.5, 2.0) + dag$add_edges_from(ebunch, weights) + expect_equal(dag$edge_count(), 2) + expect_setequal(dag$nodes(), c("A", "B", "C", "D")) + + # Test with no weights + dag2 <- RDAG$new() + dag2$add_nodes_from(c("A", "B", "C", "D"), NULL) + dag2$add_edges_from(ebunch, NULL) + expect_equal(dag2$edge_count(), 2) +}) + +test_that("active_trail_nodes returns correct trails", { + dag <- RDAG$new() + dag$add_nodes_from(c("A", "B", "C"), NULL) + dag$add_edges_from(list(c("A", "B"), c("B", "C")), NULL) + result <- dag$active_trail_nodes(c("A"), NULL, FALSE) + expect_equal(sort(result$A), sort(c("A", "B", "C"))) + + result_observed <- dag$active_trail_nodes(c("A"), c("B"), FALSE) + expect_equal(result_observed$A, "A") + + result_multi <- dag$active_trail_nodes(c("A", "C"), NULL, FALSE) + expect_equal(sort(result_multi$A), sort(c("A", "B", "C"))) + expect_equal(sort(result_multi$C), sort(c("C", "B", "A"))) +}) + +# ... (add similar fixes for other tests: add nodes before calling methods, expect specific error strings) + +test_that("RIndependencies creation and methods", { + ind <- RIndependencies$new() + asser1 <- RIndependenceAssertion$new(c("X"), c("Y"), c("Z")) + ind$add_assertion(asser1) + assertions <- ind$get_assertions() + expect_length(assertions, 1) + expect_equal(assertions[[1]]$event1(), "X") + + ind$add_assertions_from_tuples(list( + list(c("A", "B"), c("C"), c("D")), + list(c("E"), c("F"), NULL), + list(c("X"), c("Y"), c("Z")) # Duplicate + )) + expect_length(ind$get_assertions(), 4) + expect_true(all(c("X", "Y", "Z", "A", "B", "C", "D", "E", "F") %in% ind$get_all_variables())) + + expect_true(ind$contains(asser1)) + + closure <- ind$closure() + expect_s3_class(closure, "RIndependencies") + expect_gte(length(closure$get_assertions()), length(ind$get_assertions())) + + reduced <- ind$reduce(FALSE) + expect_s3_class(reduced, "RIndependencies") + expect_lte(length(reduced$get_assertions()), length(ind$get_assertions())) + + ind$reduce(TRUE) + expect_lte(length(ind$get_assertions()), 3) + + expect_true(ind$entails(reduced)) + expect_true(ind$is_equivalent(ind)) + + ind_pgmpy <- RIndependencies$new() + ind_pgmpy$add_assertions_from_tuples(list( + list(c("c"), c("a"), c("b", "e", "d")), + list(c("e", "c"), c("b"), c("a", "d")), + list(c("b", "d"), c("e"), c("a")) + )) + expect_equal(length(ind_pgmpy$closure()$get_assertions()), 14) + + ind_large <- RIndependencies$new() + ind_large$add_assertions_from_tuples(list( + list(c("c"), c("a"), c("b", "e", "d")), + list(c("e", "c"), c("b"), c("a", "d")), + list(c("b", "d"), c("e"), c("a")), + list(c("e"), c("b", "d"), c("c")), + list(c("e"), c("b", "c"), c("d")), + list(c("e", "c"), c("a"), c("b")) + )) + expect_equal(length(ind_large$closure()$get_assertions()), 78) +}) diff --git a/rust_core/src/independencies/independencies.rs b/rust_core/src/independencies/independencies.rs index 29d6b0d..e92010f 100644 --- a/rust_core/src/independencies/independencies.rs +++ b/rust_core/src/independencies/independencies.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeSet, HashMap, HashSet}; +use std::collections::{BTreeSet, HashSet}; use std::hash::{Hash, Hasher}; #[derive(Debug, Clone, Eq)] diff --git a/wasm_bindings/js/tests/test-dag.js b/wasm_bindings/js/tests/test-dag.js new file mode 100644 index 0000000..dffee2e --- /dev/null +++ b/wasm_bindings/js/tests/test-dag.js @@ -0,0 +1,92 @@ +const cg = require("../pkg-node/causalgraphs_wasm.js"); + +describe("DAG wasm (CJS)", () => { + it("should add nodes & edges", () => { + const dag = new cg.DAG(); + dag.addNode("U"); + dag.addNode("V"); + dag.addEdge("U","V"); + expect(dag.nodes()).toEqual(["U","V"]); + expect(dag.nodeCount).toBe(2); + expect(dag.edges()).toEqual([["U","V"]]); + expect(dag.edgeCount).toBe(1); + }); + + it("should check if nodes are d-connected (basic, connected)", () => { + const dag = new cg.DAG(); + dag.addEdge("A", "B"); + dag.addEdge("B", "C"); + const connected = dag.isDconnected("A", "C"); + expect(connected).toBe(true); // A -> B -> C is d-connected + }); + + it("should check if nodes are d-connected (with observed, disconnected)", () => { + const dag = new cg.DAG(); + dag.addEdge("A", "B"); + dag.addEdge("B", "C"); + const connected = dag.isDconnected("A", "C", ["B"]); // Observed B blocks the path + expect(connected).toBe(false); + }); + + it("should check if nodes are neighbors (adjacent)", () => { + const dag = new cg.DAG(); + dag.addEdge("A", "B"); + const areNeighbors = dag.areNeighbors("A", "B"); + expect(areNeighbors).toBe(true); + }); + + it("should check if nodes are neighbors (non-adjacent)", () => { + const dag = new cg.DAG(); + dag.addEdge("A", "B"); + dag.addEdge("B", "C"); + const areNeighbors = dag.areNeighbors("A", "C"); + expect(areNeighbors).toBe(false); + }); + + it("should compute minimal d-separator (simple)", () => { + const dag = new cg.DAG(); + dag.addEdge("A", "B"); + dag.addEdge("B", "C"); + const sep = dag.minimalDseparator("A", "C"); + expect(sep.sort()).toEqual(["B"]); + }); + + it("should compute minimal d-separator (complex)", () => { + const dag = new cg.DAG(); + dag.addEdge("A", "B"); + dag.addEdge("B", "C"); + dag.addEdge("C", "D"); + dag.addEdge("A", "E"); + dag.addEdge("E", "D"); + const sep = dag.minimalDseparator("A", "D"); + expect(sep.sort()).toEqual(["C", "E"]); + }); + + it("should return null for minimal d-separator if none exists (latent)", () => { + const dag = new cg.DAG(); + dag.addNode("A", false); + dag.addNode("B", true); // latent + dag.addNode("C", false); + dag.addEdge("A", "B"); + dag.addEdge("B", "C"); + const sep = dag.minimalDseparator("A", "C"); + expect(sep).toBeNull(); + }); + + it("should compute active trail nodes (basic)", () => { + const dag = new cg.DAG(); + dag.addEdge("diff", "grades"); + dag.addEdge("intel", "grades"); + const result = dag.activeTrailNodes(["diff"]); + expect(result["diff"]).toEqual(["diff", "grades"].sort()); + }); + + it("should compute active trail nodes with observed", () => { + const dag = new cg.DAG(); + dag.addEdge("diff", "grades"); + dag.addEdge("intel", "grades"); + const result = dag.activeTrailNodes(["diff", "intel"], ["grades"]); + expect(result["diff"].sort()).toEqual(["diff", "intel"].sort()); + expect(result["intel"].sort()).toEqual(["diff", "intel"].sort()); + }); +}); diff --git a/wasm_bindings/js/tests/test-independencies.js b/wasm_bindings/js/tests/test-independencies.js new file mode 100644 index 0000000..5fc98fd --- /dev/null +++ b/wasm_bindings/js/tests/test-independencies.js @@ -0,0 +1,462 @@ +const cg = require("../pkg-node/causalgraphs_wasm.js"); + +describe("IndependenceAssertion WASM", () => { + describe("Basic functionality", () => { + it("should create assertion with single elements", () => { + const assertion = new cg.JsIndependenceAssertion(["U"], ["V"], ["Z"]); + expect(assertion.event1()).toEqual(["U"]); + expect(assertion.event2()).toEqual(["V"]); + expect(assertion.event3()).toEqual(["Z"]); + expect(assertion.allVars()).toEqual(["U", "V", "Z"]); + }); + + it("should create assertion with multiple elements", () => { + const assertion = new cg.JsIndependenceAssertion(["U", "V"], ["Y", "Z"], ["A", "B"]); + expect(assertion.event1()).toEqual(["U", "V"]); + expect(assertion.event2()).toEqual(["Y", "Z"]); + expect(assertion.event3()).toEqual(["A", "B"]); + expect(assertion.allVars()).toEqual(["U", "V", "Y", "Z", "A", "B"]); + }); + + it("should create unconditional assertion", () => { + const assertion = new cg.JsIndependenceAssertion(["U"], ["V"], null); + expect(assertion.event1()).toEqual(["U"]); + expect(assertion.event2()).toEqual(["V"]); + expect(assertion.event3()).toEqual([]); + expect(assertion.allVars()).toEqual(["U", "V"]); + expect(assertion.isUnconditional()).toBe(true); + }); + + it("should handle conditional assertion", () => { + const assertion = new cg.JsIndependenceAssertion(["U"], ["V"], ["Z"]); + expect(assertion.isUnconditional()).toBe(false); + }); + }); + + describe("Validation", () => { + it("should throw error for empty event1", () => { + expect(() => { + new cg.JsIndependenceAssertion([], ["V"], ["Z"]); + }).toThrow("event1 needs to be specified"); + }); + + it("should throw error for empty event2", () => { + expect(() => { + new cg.JsIndependenceAssertion(["U"], [], ["Z"]); + }).toThrow("event2 needs to be specified"); + }); + }); + + describe("String formatting", () => { + it("should format conditional assertion correctly", () => { + const assertion = new cg.JsIndependenceAssertion(["U"], ["V"], ["Z"]); + expect(assertion.toLatex()).toBe("U \\perp V \\mid Z"); + expect(assertion.toString()).toBe("(U ⊥ V | Z)"); + }); + + it("should format unconditional assertion correctly", () => { + const assertion = new cg.JsIndependenceAssertion(["U"], ["V"], null); + expect(assertion.toLatex()).toBe("U \\perp V"); + expect(assertion.toString()).toBe("(U ⊥ V)"); + }); + + it("should format multi-element assertion correctly", () => { + const assertion = new cg.JsIndependenceAssertion(["U", "V"], ["Y", "Z"], ["A", "B"]); + expect(assertion.toLatex()).toBe("U, V \\perp Y, Z \\mid A, B"); + expect(assertion.toString()).toBe("(U, V ⊥ Y, Z | A, B)"); + }); + }); + + describe("Equality", () => { + it("should handle basic equality", () => { + const i1 = new cg.JsIndependenceAssertion(["a"], ["b"], ["c"]); + const i2 = new cg.JsIndependenceAssertion(["a"], ["b"], null); + const i3 = new cg.JsIndependenceAssertion(["a"], ["b", "c", "d"], null); + + expect(i1.toString()).not.toBe(i2.toString()); + expect(i1.toString()).not.toBe(i3.toString()); + expect(i2.toString()).not.toBe(i3.toString()); + }); + + it("should handle symmetry", () => { + const i4 = new cg.JsIndependenceAssertion(["a"], ["b", "c", "d"], ["e"]); + const i5 = new cg.JsIndependenceAssertion(["a"], ["d", "c", "b"], ["e"]); + + // Order shouldn't matter for sets + expect(i4.toString()).toBe(i5.toString()); + }); + + it("should handle swapped events", () => { + const i9 = new cg.JsIndependenceAssertion(["a"], ["d", "k", "b"], ["e"]); + const i10 = new cg.JsIndependenceAssertion(["k", "b", "d"], ["a"], ["e"]); + + // Should be equal due to symmetry + expect(i9.toString()).toBe(i10.toString()); + }); + }); +}); + +describe("Independencies WASM", () => { + describe("Basic functionality", () => { + it("should create empty independencies", () => { + const ind = new cg.JsIndependencies(); + expect(ind.getAssertions()).toEqual([]); + expect(ind.getAllVariables()).toEqual([]); + }); + + it("should add assertion", () => { + const ind = new cg.JsIndependencies(); + const assertion = new cg.JsIndependenceAssertion(["X"], ["Y"], ["Z"]); + ind.addAssertion(assertion); + expect(ind.getAssertions()).toHaveLength(1); + expect(ind.contains(assertion)).toBe(true); + }); + + it("should add assertions from tuples", () => { + const ind = new cg.JsIndependencies(); + const tuples = [ + [["X"], ["Y"], ["Z"]], + [["A"], ["B"], ["C"]] + ]; + ind.addAssertionsFromTuples(tuples); + expect(ind.getAssertions()).toHaveLength(2); + }); + + it("should get all variables", () => { + const ind = new cg.JsIndependencies(); + ind.addAssertionsFromTuples([ + [["a"], ["b", "c", "d"], ["e", "f", "g"]], + [["c"], ["d", "e", "f"], ["g", "h"]] + ]); + const vars = ind.getAllVariables(); + expect(vars).toContain("a"); + expect(vars).toContain("b"); + expect(vars).toContain("c"); + expect(vars).toContain("d"); + expect(vars).toContain("e"); + expect(vars).toContain("f"); + expect(vars).toContain("g"); + expect(vars).toContain("h"); + expect(vars).toHaveLength(8); + }); + }); + + describe("Closure", () => { + it("should compute simple closure", () => { + const ind = new cg.JsIndependencies(); + ind.addAssertionsFromTuples([ + [["A"], ["B", "C"], ["D"]] + ]); + + const closure = ind.closure(); + const assertions = closure.getAssertions(); + + // Should contain original assertion and decompositions + expect(assertions.length).toBeGreaterThanOrEqual(1); + + // Check for decompositions: A ⊥ B | D and A ⊥ C | D + const assertionStrings = assertions.map(a => a.toString()); + expect(assertionStrings.some(s => s.includes("(A ⊥ B | D)"))).toBe(true); + expect(assertionStrings.some(s => s.includes("(A ⊥ C | D)"))).toBe(true); + }); + + it("should compute complex closure", () => { + const ind = new cg.JsIndependencies(); + ind.addAssertionsFromTuples([ + [["A"], ["B", "C", "D"], ["E"]] + ]); + + const closure = ind.closure(); + const assertions = closure.getAssertions(); + + // Should generate multiple assertions through semi-graphoid axioms + expect(assertions.length).toBeGreaterThan(1); + }); + + it("should compute closure for unconditional assertion", () => { + const ind = new cg.JsIndependencies(); + ind.addAssertionsFromTuples([ + [["W"], ["X", "Y", "Z"], null] + ]); + + const closure = ind.closure(); + const assertions = closure.getAssertions(); + + // Should generate multiple assertions + expect(assertions.length).toBeGreaterThan(1); + + // Check for specific expected assertions + const assertionStrings = assertions.map(a => a.toString()); + expect(assertionStrings.some(s => s.includes("(W ⊥ X)"))).toBe(true); + expect(assertionStrings.some(s => s.includes("(W ⊥ Y)"))).toBe(true); + expect(assertionStrings.some(s => s.includes("(W ⊥ Z)"))).toBe(true); + }); + }); + + describe("Entailment", () => { + it("should test entailment", () => { + const ind1 = new cg.JsIndependencies(); + ind1.addAssertionsFromTuples([ + [["W"], ["X", "Y", "Z"], null] + ]); + + const ind2 = new cg.JsIndependencies(); + ind2.addAssertionsFromTuples([ + [["W"], ["X"], null] + ]); + + // W ⊥ X,Y,Z should entail W ⊥ X + expect(ind1.entails(ind2)).toBe(true); + expect(ind2.entails(ind1)).toBe(false); + }); + + it("should test self-entailment", () => { + const ind = new cg.JsIndependencies(); + ind.addAssertionsFromTuples([ + [["W"], ["X", "Y", "Z"], null] + ]); + + const closure = ind.closure(); + expect(ind.entails(closure)).toBe(true); + expect(closure.entails(ind)).toBe(true); + }); + }); + + describe("Equivalence", () => { + it("should test equivalence", () => { + const ind1 = new cg.JsIndependencies(); + ind1.addAssertionsFromTuples([ + [["X"], ["Y", "W"], ["Z"]] + ]); + + const ind2 = new cg.JsIndependencies(); + ind2.addAssertionsFromTuples([ + [["X"], ["Y"], ["Z"]], + [["X"], ["W"], ["Z"]] + ]); + + const ind3 = new cg.JsIndependencies(); + ind3.addAssertionsFromTuples([ + [["X"], ["Y"], ["Z"]], + [["X"], ["W"], ["Z"]], + [["X"], ["Y"], ["W", "Z"]] + ]); + + // ind1 should NOT be equivalent to ind2 + expect(ind1.isEquivalent(ind2)).toBe(false); + + // ind1 should be equivalent to ind3 + expect(ind1.isEquivalent(ind3)).toBe(true); + }); + + it("should test symmetric equivalence", () => { + const indA = new cg.JsIndependencies(); + indA.addAssertionsFromTuples([ + [["X", "Y"], ["A", "B"], ["Z"]], + [["P"], ["Q", "R", "S"], ["T", "U"]] + ]); + + const indB = new cg.JsIndependencies(); + indB.addAssertionsFromTuples([ + [["A", "B"], ["X", "Y"], ["Z"]], + [["P"], ["S", "Q", "R"], ["U", "T"]] + ]); + + // These should be equal due to symmetric equivalence and set ordering + expect(indA.isEquivalent(indB)).toBe(true); + }); + }); + + describe("Reduce", () => { + it("should reduce duplicates", () => { + const ind = new cg.JsIndependencies(); + const assertion = new cg.JsIndependenceAssertion(["X"], ["Y"], ["Z"]); + + // Add the same assertion twice + ind.addAssertion(assertion); + ind.addAssertion(assertion); + + const reduced = ind.reduce(); + expect(reduced.getAssertions()).toHaveLength(1); + }); + + it("should reduce entailment", () => { + const ind = new cg.JsIndependencies(); + + // More general assertion + ind.addAssertionsFromTuples([ + [["W"], ["X", "Y", "Z"], null] + ]); + // More specific assertion (should be removed) + ind.addAssertionsFromTuples([ + [["W"], ["X"], null] + ]); + + const reduced = ind.reduce(); + expect(reduced.getAssertions()).toHaveLength(1); + + // Should keep the more general assertion + const general = new cg.JsIndependenceAssertion(["W"], ["X", "Y", "Z"], null); + expect(reduced.contains(general)).toBe(true); + }); + + it("should reduce independent assertions", () => { + const ind = new cg.JsIndependencies(); + ind.addAssertionsFromTuples([ + [["A"], ["B"], ["C"]], + [["D"], ["B"], ["F"]] + ]); + + const reduced = ind.reduce(); + expect(reduced.getAssertions()).toHaveLength(2); + }); + + it("should reduce complex case", () => { + const ind = new cg.JsIndependencies(); + + // General assertion that entails the specific ones + ind.addAssertionsFromTuples([ + [["A"], ["B", "C"], ["D"]] + ]); + // Specific assertions that should be removed + ind.addAssertionsFromTuples([ + [["A"], ["B"], ["D"]], + [["A"], ["C"], ["D"]] + ]); + // Independent assertion + ind.addAssertionsFromTuples([ + [["E"], ["F"], ["G"]] + ]); + + const reduced = ind.reduce(); + expect(reduced.getAssertions()).toHaveLength(2); + + const general = new cg.JsIndependenceAssertion(["A"], ["B", "C"], ["D"]); + const independent = new cg.JsIndependenceAssertion(["E"], ["F"], ["G"]); + + expect(reduced.contains(general)).toBe(true); + expect(reduced.contains(independent)).toBe(true); + }); + + it("should reduce empty independencies", () => { + const ind = new cg.JsIndependencies(); + const reduced = ind.reduce(); + expect(reduced.getAssertions()).toHaveLength(0); + }); + }); + + describe("Complex scenarios", () => { + it("should handle complex multi-assertion equality", () => { + const ind3 = new cg.JsIndependencies(); + ind3.addAssertionsFromTuples([ + [["a"], ["b", "c", "d"], ["e", "f", "g"]], + [["c"], ["d", "e", "f"], ["g", "h"]] + ]); + + const ind4 = new cg.JsIndependencies(); + ind4.addAssertionsFromTuples([ + [["f", "d", "e"], ["c"], ["h", "g"]], + [["b", "c", "d"], ["a"], ["f", "g", "e"]] + ]); + + const ind5 = new cg.JsIndependencies(); + ind5.addAssertionsFromTuples([ + [["a"], ["b", "c", "d"], ["e", "f", "g"]], + [["c"], ["d", "e", "f"], ["g"]] + ]); + + // These should be equal due to symmetric equivalence + expect(ind3.isEquivalent(ind4)).toBe(true); + + // These should not be equal + expect(ind3.isEquivalent(ind5)).toBe(false); + expect(ind4.isEquivalent(ind5)).toBe(false); + }); + + it("should handle large closure case", () => { + const ind = new cg.JsIndependencies(); + ind.addAssertionsFromTuples([ + [["c"], ["a"], ["b", "e", "d"]], + [["e", "c"], ["b"], ["a", "d"]], + [["b", "d"], ["e"], ["a"]], + [["e"], ["b", "d"], ["c"]], + [["e"], ["b", "c"], ["d"]], + [["e", "c"], ["a"], ["b"]] + ]); + + const closure = ind.closure(); + const assertions = closure.getAssertions(); + + // Should generate many assertions + expect(assertions.length).toBeGreaterThan(50); + }); + + it("should handle WXYZ closure case", () => { + const ind = new cg.JsIndependencies(); + ind.addAssertionsFromTuples([ + [["W"], ["X", "Y", "Z"], null] + ]); + + const closure = ind.closure(); + const assertions = closure.getAssertions(); + + // Should generate exactly 19 assertions for this case + expect(assertions).toHaveLength(19); + + // Check for specific expected assertions + const assertionStrings = assertions.map(a => a.toString()); + expect(assertionStrings).toContain("(W ⊥ X)"); + expect(assertionStrings).toContain("(W ⊥ Y)"); + expect(assertionStrings).toContain("(W ⊥ Z)"); + expect(assertionStrings).toContain("(W ⊥ X, Y)"); + expect(assertionStrings).toContain("(W ⊥ X, Z)"); + expect(assertionStrings).toContain("(W ⊥ Y, Z)"); + expect(assertionStrings).toContain("(W ⊥ X, Y, Z)"); + }); + }); + + describe("Edge cases", () => { + it("should handle empty independencies comparison", () => { + const empty1 = new cg.JsIndependencies(); + const empty2 = new cg.JsIndependencies(); + const nonEmpty = new cg.JsIndependencies(); + nonEmpty.addAssertionsFromTuples([ + [["A"], ["B"], ["C"]] + ]); + + // Empty vs non-empty should be false + expect(empty1.isEquivalent(nonEmpty)).toBe(false); + + // Non-empty vs empty should be false + expect(nonEmpty.isEquivalent(empty1)).toBe(false); + + // Empty vs empty should be true + expect(empty1.isEquivalent(empty2)).toBe(true); + }); + + it("should handle bidirectional equivalence", () => { + const indX = new cg.JsIndependencies(); + indX.addAssertionsFromTuples([ + [["A"], ["B", "C"], ["D"]], + [["E"], ["F"], ["G", "H"]] + ]); + + const indY = new cg.JsIndependencies(); + indY.addAssertionsFromTuples([ + [["A"], ["B"], ["D"]], + [["A"], ["C"], ["D"]], + [["E"], ["F"], ["G", "H"]] + ]); + + // Test that decomposition creates equivalence + expect(indX.entails(indY)).toBe(true); + + // Test bidirectional equivalence + const reverseEntailment = indY.entails(indX); + if (reverseEntailment) { + expect(indX.isEquivalent(indY)).toBe(true); + expect(indY.isEquivalent(indX)).toBe(true); + } + }); + }); +}); \ No newline at end of file diff --git a/wasm_bindings/js/tests/test-wasm.js b/wasm_bindings/js/tests/test-wasm.js index 76f8075..50217b7 100644 --- a/wasm_bindings/js/tests/test-wasm.js +++ b/wasm_bindings/js/tests/test-wasm.js @@ -16,9 +16,9 @@ const latentsList = (dag) => { return []; }; -describe("RustDAG wasm (CJS)", () => { +describe("DAG wasm (CJS)", () => { it("should add nodes & edges (basic)", () => { - const dag = new cg.RustDAG(); + const dag = new cg.DAG(); dag.addNode("U"); dag.addNode("V"); dag.addEdge("U", "V"); @@ -29,7 +29,7 @@ describe("RustDAG wasm (CJS)", () => { }); it("addNode with optional latent flag; latents getter", () => { - const dag = new cg.RustDAG(); + const dag = new cg.DAG(); dag.addNode("A"); dag.addNode("L", true); expect(sortStrings(dag.nodes())).toEqual(["A", "L"]); @@ -40,7 +40,7 @@ describe("RustDAG wasm (CJS)", () => { }); it("addNodesFrom with optional latent mask (Uint8Array)", () => { - const dag = new cg.RustDAG(); + const dag = new cg.DAG(); dag.addNodesFrom(["X", "Y", "Z"], [true, false, true]); expect(sortStrings(dag.nodes())).toEqual(["X", "Y", "Z"]); @@ -49,7 +49,7 @@ describe("RustDAG wasm (CJS)", () => { }); it("getParents and getChildren", () => { - const dag = new cg.RustDAG(); + const dag = new cg.DAG(); dag.addNodesFrom(["A", "B", "C", "D"]); dag.addEdge("A", "B"); dag.addEdge("A", "C"); @@ -61,7 +61,7 @@ describe("RustDAG wasm (CJS)", () => { }); it("getAncestorsOf for a single target", () => { - const dag = new cg.RustDAG(); + const dag = new cg.DAG(); dag.addNodesFrom(["A", "B", "C", "D"]); dag.addEdge("A", "B"); dag.addEdge("A", "C"); @@ -73,7 +73,7 @@ describe("RustDAG wasm (CJS)", () => { }); it("getAncestorsOf for multiple targets", () => { - const dag = new cg.RustDAG(); + const dag = new cg.DAG(); dag.addNodesFrom(["A", "B", "C", "D", "E"]); dag.addEdge("A", "B"); dag.addEdge("B", "C"); @@ -85,7 +85,7 @@ describe("RustDAG wasm (CJS)", () => { }); it("edges reflects added edges (order-insensitive)", () => { - const dag = new cg.RustDAG(); + const dag = new cg.DAG(); dag.addNodesFrom(["A", "B", "C"]); dag.addEdge("A", "B"); dag.addEdge("B", "C"); @@ -102,7 +102,7 @@ describe("RustDAG wasm (CJS)", () => { }); it("addEdge can take an optional weight (graph relations still correct)", () => { - const dag = new cg.RustDAG(); + const dag = new cg.DAG(); dag.addNodesFrom(["S", "T"]); dag.addEdge("S", "T", 0.75); @@ -111,7 +111,7 @@ describe("RustDAG wasm (CJS)", () => { }); it("nodeCount / edgeCount track mutations", () => { - const dag = new cg.RustDAG(); + const dag = new cg.DAG(); expect(dag.nodeCount).toBe(0); expect(dag.edgeCount).toBe(0); diff --git a/wasm_bindings/src/lib.rs b/wasm_bindings/src/lib.rs index 76ca96d..81f03cf 100644 --- a/wasm_bindings/src/lib.rs +++ b/wasm_bindings/src/lib.rs @@ -1,17 +1,20 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; use wasm_bindgen::prelude::*; +use rust_core::{IndependenceAssertion, Independencies}; +use js_sys::{Object, Array}; -#[wasm_bindgen(js_name = RustDAG)] -pub struct RustDAG { +#[wasm_bindgen(js_name = DAG)] +#[derive(Clone)] +pub struct DAG { inner: rust_core::RustDAG, } #[wasm_bindgen] -impl RustDAG { +impl DAG { #[wasm_bindgen(constructor)] - pub fn new() -> RustDAG { - RustDAG { + pub fn new() -> DAG { + DAG { inner: rust_core::RustDAG::new(), } } @@ -92,8 +95,237 @@ impl RustDAG { serde_wasm_bindgen::to_value(&self.inner.latents) .unwrap_or_else(|_| JsValue::from_str("Failed to serialize latents")) } + + // In RustDAG impl + #[wasm_bindgen(js_name = minimalDseparator, catch)] + pub fn minimal_dseparator(&self, start: String, end: String, include_latents: Option) -> Result { + let result = self.inner.minimal_dseparator(&start, &end, include_latents.unwrap_or(false)) + .map_err(|e| JsValue::from_str(&e))?; + + match result { + Some(mut set) => { + let mut vec: Vec = set.drain().collect(); + vec.sort(); + let js_array = Array::new(); + for item in vec { + js_array.push(&JsValue::from_str(&item)); + } + Ok(js_array.into()) // Return JS Array + } + None => Ok(JsValue::NULL), + } + } + + #[wasm_bindgen(js_name = activeTrailNodes, catch)] + pub fn active_trail_nodes(&self, variables: Vec, observed: Option>, include_latents: Option) -> Result { + let result = self.inner.active_trail_nodes(variables, observed, include_latents.unwrap_or(false)) + .map_err(|e| JsValue::from_str(&e))?; + + // Create a plain JS Object + let js_object = Object::new(); + + for (key, mut set) in result { + let mut vec: Vec = set.drain().collect(); + vec.sort(); + + let js_array = Array::new(); + for item in vec { + js_array.push(&JsValue::from_str(&item)); + } + + // Set property on object (key: array) + js_sys::Reflect::set(&js_object, &JsValue::from_str(&key), &js_array.into()) + .map_err(|_| JsValue::from_str("Failed to set property"))?; + } + + Ok(js_object.into()) + } + + #[wasm_bindgen(js_name = isDconnected, catch)] + pub fn is_dconnected( + &self, + start: String, + end: String, + observed: Option>, + include_latents: Option, + ) -> Result { + self.inner.is_dconnected(&start, &end, observed, include_latents.unwrap_or(false)) + .map_err(|e| JsValue::from_str(&e)) + } + + #[wasm_bindgen(js_name = areNeighbors, catch)] + pub fn are_neighbors(&self, start: String, end: String) -> Result { + self.inner.are_neighbors(&start, &end) + .map_err(|e| JsValue::from_str(&e)) + } +} + + +#[wasm_bindgen] +#[derive(Clone)] +pub struct JsIndependenceAssertion { + inner: IndependenceAssertion, +} + +#[wasm_bindgen] +impl JsIndependenceAssertion { + #[wasm_bindgen(constructor)] + pub fn new(event1: Vec, event2: Vec, event3: Option>) -> Result { + let e1: HashSet = event1.into_iter().collect(); + let e2: HashSet = event2.into_iter().collect(); + let e3: Option> = event3.map(|v| v.into_iter().collect()); + let assertion = IndependenceAssertion::new(e1, e2, e3) + .map_err(|e| JsValue::from_str(&e))?; + Ok(JsIndependenceAssertion { inner: assertion }) + } + + #[wasm_bindgen(js_name = event1)] + pub fn event1(&self) -> Vec { + self.inner.event1.iter().cloned().collect() + } + + #[wasm_bindgen(js_name = event2)] + pub fn event2(&self) -> Vec { + self.inner.event2.iter().cloned().collect() + } + + #[wasm_bindgen(js_name = event3)] + pub fn event3(&self) -> Vec { + let mut e3_vec: Vec = self.inner.event3.iter().cloned().collect(); + e3_vec.sort(); + e3_vec + } + + #[wasm_bindgen(js_name = allVars)] + pub fn all_vars(&self) -> Vec { + // Return variables in the order: event1, event2, event3 + // Sort within each set for consistency + let mut all_vars_vec = Vec::new(); + + // Add event1 variables (sorted) + let mut e1_vec: Vec = self.inner.event1.iter().cloned().collect(); + e1_vec.sort(); + all_vars_vec.extend(e1_vec); + + // Add event2 variables (sorted) + let mut e2_vec: Vec = self.inner.event2.iter().cloned().collect(); + e2_vec.sort(); + all_vars_vec.extend(e2_vec); + + // Add event3 variables (sorted) + let mut e3_vec: Vec = self.inner.event3.iter().cloned().collect(); + e3_vec.sort(); + all_vars_vec.extend(e3_vec); + + all_vars_vec + } + + #[wasm_bindgen(js_name = isUnconditional)] + pub fn is_unconditional(&self) -> bool { + self.inner.is_unconditional() + } + + #[wasm_bindgen(js_name = toLatex)] + pub fn to_latex(&self) -> String { + self.inner.to_latex() + } + + #[wasm_bindgen(js_name = toString)] + pub fn to_string(&self) -> String { + // Create a canonical representation that handles symmetry + let mut e1_vec: Vec = self.inner.event1.iter().cloned().collect(); + let mut e2_vec: Vec = self.inner.event2.iter().cloned().collect(); + e1_vec.sort(); + e2_vec.sort(); + + // For symmetry, ensure consistent ordering: put the lexicographically smaller set first + let (first, second) = if e1_vec < e2_vec { + (e1_vec, e2_vec) + } else { + (e2_vec, e1_vec) + }; + + let first_str = first.join(", "); + let second_str = second.join(", "); + + if self.inner.event3.is_empty() { + format!("({} ⊥ {})", first_str, second_str) + } else { + let mut e3_vec: Vec = self.inner.event3.iter().cloned().collect(); + e3_vec.sort(); + let e3_str = e3_vec.join(", "); + format!("({} ⊥ {} | {})", first_str, second_str, e3_str) + } + } +} + +#[wasm_bindgen] +#[derive(Clone)] +pub struct JsIndependencies { + inner: Independencies, +} + +#[wasm_bindgen] +impl JsIndependencies { + #[wasm_bindgen(constructor)] + pub fn new() -> JsIndependencies { + JsIndependencies { inner: Independencies::new() } + } + + #[wasm_bindgen(js_name = addAssertion)] + pub fn add_assertion(&mut self, assertion: &JsIndependenceAssertion) { + self.inner.add_assertion(assertion.inner.clone()); + } + + #[wasm_bindgen(js_name = addAssertionsFromTuples)] + pub fn add_assertions_from_tuples(&mut self, tuples: JsValue) -> Result<(), JsValue> { + let tuples: Vec<(Vec, Vec, Option>)> = + serde_wasm_bindgen::from_value(tuples) + .map_err(|e| JsValue::from_str(&e.to_string()))?; + self.inner.add_assertions_from_tuples(tuples) + .map_err(|e| JsValue::from_str(&e)) + } + + #[wasm_bindgen(js_name = getAssertions)] + pub fn get_assertions(&self) -> Vec { + self.inner.get_assertions() + .iter() + .map(|a| JsIndependenceAssertion { inner: a.clone() }) + .collect() + } + + #[wasm_bindgen(js_name = getAllVariables)] + pub fn get_all_variables(&self) -> Vec { + self.inner.get_all_variables().into_iter().collect() + } + + #[wasm_bindgen(js_name = contains)] + pub fn contains(&self, assertion: &JsIndependenceAssertion) -> bool { + self.inner.contains(&assertion.inner) + } + + #[wasm_bindgen(js_name = closure)] + pub fn closure(&self) -> JsIndependencies { + JsIndependencies { inner: self.inner.closure() } + } + + #[wasm_bindgen(js_name = reduce)] + pub fn reduce(&self) -> JsIndependencies { + JsIndependencies { inner: self.inner.reduce() } + } + + #[wasm_bindgen(js_name = entails)] + pub fn entails(&self, other: &JsIndependencies) -> bool { + self.inner.entails(&other.inner) + } + + #[wasm_bindgen(js_name = isEquivalent)] + pub fn is_equivalent(&self, other: &JsIndependencies) -> bool { + self.inner.is_equivalent(&other.inner) + } } + // Optional: Add a start function for debugging or initialization #[wasm_bindgen(start)] pub fn main_js() -> Result<(), JsValue> {