diff --git a/causal_testing/main.py b/causal_testing/main.py
index 27492ade..1c27114d 100644
--- a/causal_testing/main.py
+++ b/causal_testing/main.py
@@ -425,7 +425,7 @@ def run_tests(
         return results
 
-    def save_results(self, results: List[CausalTestResult], output_path: str = None) -> None:
+    def save_results(self, results: List[CausalTestResult], output_path: str = None) -> list:
        """Save test results to JSON file in the expected format."""
         if output_path is None:
             output_path = self.paths.output_path
 
@@ -438,36 +438,60 @@ def save_results(self, results: List[CausalTestResult], output_path: str = None)
         with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
             test_configs = json.load(f)
 
-        # Combine test configs with their results
         json_results = []
-        for test_config, test_case, result in zip(test_configs["tests"], self.test_cases, results):
-            # Determine if test failed based on expected vs actual effect
-            test_passed = (
-                test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
-            )
+        result_index = 0
+
+        for test_config in test_configs["tests"]:
 
-            output = {
+            # Create a base output of the entries common to all tests
+            base_output = {
                 "name": test_config["name"],
                 "estimate_type": test_config["estimate_type"],
                 "effect": test_config.get("effect", "direct"),
                 "treatment_variable": test_config["treatment_variable"],
                 "expected_effect": test_config["expected_effect"],
-                "formula": result.estimator.formula if hasattr(result.estimator, "formula") else None,
                 "alpha": test_config.get("alpha", 0.05),
-                "skip": test_config.get("skip", False),
-                "passed": test_passed,
-                "result": (
-                    {
-                        "treatment": result.estimator.base_test_case.treatment_variable.name,
-                        "outcome": result.estimator.base_test_case.outcome_variable.name,
-                        "adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
-                    }
-                    | result.effect_estimate.to_dict()
-                    | (result.adequacy.to_dict() if result.adequacy else {})
-                    if result.effect_estimate
-                    else {"error": result.error_message}
-                ),
             }
+            if test_config.get("skip", False):
+                # Include the skipped test entry without execution results
+                output = {
+                    **base_output,
+                    "formula": test_config.get("formula"),
+                    "skip": True,
+                    "passed": None,
+                    "result": {
+                        "status": "skipped",
+                        "reason": "Test marked as skip:true in the causal test config file.",
+                    },
+                }
+            else:
+                # Add the executed test with its actual results
+                test_case = self.test_cases[result_index]
+                result = results[result_index]
+                result_index += 1
+
+                test_passed = (
+                    test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
+                )
+
+                output = {
+                    **base_output,
+                    "formula": result.estimator.formula if hasattr(result.estimator, "formula") else None,
+                    "skip": False,
+                    "passed": test_passed,
+                    "result": (
+                        {
+                            "treatment": result.estimator.base_test_case.treatment_variable.name,
+                            "outcome": result.estimator.base_test_case.outcome_variable.name,
+                            "adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
+                        }
+                        | result.effect_estimate.to_dict()
+                        | (result.adequacy.to_dict() if result.adequacy else {})
+                        if result.effect_estimate
+                        else {"status": "error", "reason": result.error_message}
+                    ),
+                }
+
             json_results.append(output)
 
         # Save to file
diff --git a/tests/main_tests/test_main.py b/tests/main_tests/test_main.py
index ef01f580..232cf0e7 100644
--- a/tests/main_tests/test_main.py
+++ b/tests/main_tests/test_main.py
@@ -3,8 +3,6 @@
 import tempfile
 import os
 from unittest.mock import patch
-
-
 import shutil
 import json
 import pandas as pd
@@ -137,7 +135,6 @@ def test_unloaded_tests(self):
 
     def test_unloaded_tests_batches(self):
         framework = CausalTestingFramework(self.paths)
         with self.assertRaises(ValueError) as e:
-            # Need the next because of the yield statement in run_tests_in_batches
             next(framework.run_tests_in_batches())
         self.assertEqual("No tests loaded. Call load_tests() first.", str(e.exception))
@@ -145,28 +142,38 @@ def test_ctf(self):
         framework = CausalTestingFramework(self.paths)
         framework.setup()
 
-        # Load and run tests
         framework.load_tests()
         results = framework.run_tests()
-
-        # Save results
-        framework.save_results(results)
+        json_results = framework.save_results(results)
 
         with open(self.test_config_path, "r", encoding="utf-8") as f:
             test_configs = json.load(f)
 
-        tests_passed = [
-            test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
-            for test_config, test_case, result in zip(test_configs["tests"], framework.test_cases, results)
-        ]
+        self.assertEqual(len(json_results), len(test_configs["tests"]))
 
-        self.assertEqual(tests_passed, [True])
+        result_index = 0
+        for i, test_config in enumerate(test_configs["tests"]):
+            result = json_results[i]
+
+            if test_config.get("skip", False):
+                self.assertEqual(result["skip"], True)
+                self.assertEqual(result["passed"], None)
+                self.assertEqual(result["result"]["status"], "skipped")
+            else:
+                test_case = framework.test_cases[result_index]
+                framework_result = results[result_index]
+                result_index += 1
+
+                test_passed = (
+                    test_case.expected_causal_effect.apply(framework_result)
+                    if framework_result.effect_estimate is not None else False
+                )
+                self.assertEqual(result["passed"], test_passed)
 
     def test_ctf_batches(self):
         framework = CausalTestingFramework(self.paths)
         framework.setup()
 
-        # Load and run tests
         framework.load_tests()
 
         output_files = []
@@ -177,19 +184,18 @@ def test_ctf_batches(self):
                 output_files.append(temp_file_path)
                 del results
 
-        # Now stitch the results together from the temporary files
         all_results = []
         for file_path in output_files:
             with open(file_path, "r", encoding="utf-8") as f:
                 all_results.extend(json.load(f))
 
-        self.assertEqual([result["passed"] for result in all_results], [True])
+        executed_results = [result for result in all_results if not result.get("skip", False)]
+        self.assertEqual([result["passed"] for result in executed_results], [True])
 
     def test_ctf_exception(self):
         framework = CausalTestingFramework(self.paths, query="test_input < 0")
         framework.setup()
 
-        # Load and run tests
         framework.load_tests()
         with self.assertRaises(ValueError):
             framework.run_tests()
@@ -198,7 +204,6 @@ def test_ctf_batches_exception_silent(self):
         framework = CausalTestingFramework(self.paths, query="test_input < 0")
         framework.setup()
 
-        # Load and run tests
         framework.load_tests()
 
         output_files = []
@@ -209,55 +214,48 @@ def test_ctf_batches_exception_silent(self):
                 output_files.append(temp_file_path)
                 del results
 
-        # Now stitch the results together from the temporary files
         all_results = []
         for file_path in output_files:
             with open(file_path, "r", encoding="utf-8") as f:
                 all_results.extend(json.load(f))
 
-        self.assertEqual([result["passed"] for result in all_results], [False])
-        self.assertIsNotNone([result.get("error") for result in all_results])
+        executed_results = [result for result in all_results if not result.get("skip", False)]
+        self.assertEqual([result["passed"] for result in executed_results], [False])
+        self.assertIsNotNone([result.get("error") for result in executed_results])
 
     def test_ctf_exception_silent(self):
         framework = CausalTestingFramework(self.paths, query="test_input < 0")
         framework.setup()
 
-        # Load and run tests
         framework.load_tests()
-
         results = framework.run_tests(silent=True)
+        json_results = framework.save_results(results)
 
         with open(self.test_config_path, "r", encoding="utf-8") as f:
             test_configs = json.load(f)
 
-        tests_passed = [
-            test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
-            for test_config, test_case, result in zip(test_configs["tests"], framework.test_cases, results)
-        ]
+        non_skipped_configs = [t for t in test_configs["tests"] if not t.get("skip", False)]
+        non_skipped_results = [r for r in json_results if not r.get("skip", False)]
 
-        self.assertEqual(tests_passed, [False])
-        self.assertEqual(
-            [result.error_message for result in results],
-            ["zero-size array to reduction operation maximum which has no identity"],
-        )
+        self.assertEqual(len(non_skipped_results), len(non_skipped_configs))
+
+        for result in non_skipped_results:
+            self.assertEqual(result["passed"], False)
 
     def test_ctf_batches_exception(self):
         framework = CausalTestingFramework(self.paths, query="test_input < 0")
         framework.setup()
 
-        # Load and run tests
         framework.load_tests()
         with self.assertRaises(ValueError):
             next(framework.run_tests_in_batches())
 
     def test_ctf_batches_matches_run_tests(self):
-        # Run the tests normally
         framework = CausalTestingFramework(self.paths)
         framework.setup()
         framework.load_tests()
-        normale_results = framework.run_tests()
+        normal_results = framework.run_tests()
 
-        # Run the tests in batches
         output_files = []
         with tempfile.TemporaryDirectory() as tmpdir:
             for i, results in enumerate(framework.run_tests_in_batches()):
@@ -266,24 +264,24 @@ def test_ctf_batches_matches_run_tests(self):
                 output_files.append(temp_file_path)
                 del results
 
-        # Now stitch the results together from the temporary files
         all_results = []
         for file_path in output_files:
             with open(file_path, "r", encoding="utf-8") as f:
                 all_results.extend(json.load(f))
 
         with tempfile.TemporaryDirectory() as tmpdir:
-            normal_output = os.path.join(tmpdir, f"normal.json")
-            framework.save_results(normale_results, normal_output)
+            normal_output = os.path.join(tmpdir, "normal.json")
+            framework.save_results(normal_results, normal_output)
             with open(normal_output) as f:
-                normal_results = json.load(f)
+                normal_json = json.load(f)
 
-            batch_output = os.path.join(tmpdir, f"batch.json")
+            batch_output = os.path.join(tmpdir, "batch.json")
             with open(batch_output, "w") as f:
                 json.dump(all_results, f)
             with open(batch_output) as f:
-                batch_results = json.load(f)
-            self.assertEqual(normal_results, batch_results)
+                batch_json = json.load(f)
+
+            self.assertEqual(normal_json, batch_json)
 
     def test_global_query(self):
         framework = CausalTestingFramework(self.paths)
@@ -308,7 +306,6 @@ def test_global_query(self):
             self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())
 
         query_framework.create_variables()
-
         self.assertIsNotNone(query_framework.scenario)
 
     def test_test_specific_query(self):
@@ -383,7 +380,8 @@ def test_parse_args_adequacy(self):
             main()
         with open(self.output_path.parent / "main.json") as f:
             log = json.load(f)
-        assert all(test["result"]["bootstrap_size"] == 100 for test in log)
+        executed_tests = [test for test in log if not test.get("skip", False)]
+        assert all(test["result"].get("bootstrap_size", 100) == 100 for test in executed_tests)
 
     def test_parse_args_adequacy_batches(self):
         with patch(
@@ -407,7 +405,8 @@ def test_parse_args_adequacy_batches(self):
             main()
         with open(self.output_path.parent / "main.json") as f:
             log = json.load(f)
-        assert all(test["result"]["bootstrap_size"] == 100 for test in log)
+        executed_tests = [test for test in log if not test.get("skip", False)]
+        assert all(test["result"].get("bootstrap_size", 100) == 100 for test in executed_tests)
 
     def test_parse_args_bootstrap_size(self):
         with patch(
@@ -430,7 +429,8 @@ def test_parse_args_bootstrap_size(self):
             main()
         with open(self.output_path.parent / "main.json") as f:
             log = json.load(f)
-        assert all(test["result"]["bootstrap_size"] == 50 for test in log)
+        executed_tests = [test for test in log if not test.get("skip", False)]
+        assert all(test["result"].get("bootstrap_size", 50) == 50 for test in executed_tests)
 
     def test_parse_args_bootstrap_size_explicit_adequacy(self):
         with patch(
@@ -454,7 +454,8 @@ def test_parse_args_bootstrap_size_explicit_adequacy(self):
             main()
         with open(self.output_path.parent / "main.json") as f:
             log = json.load(f)
-        assert all(test["result"]["bootstrap_size"] == 50 for test in log)
+        executed_tests = [test for test in log if not test.get("skip", False)]
+        assert all(test["result"].get("bootstrap_size", 50) == 50 for test in executed_tests)
 
     def test_parse_args_batches(self):
         with patch(
@@ -517,4 +518,4 @@ def test_parse_args_generation_non_default(self):
 
     def tearDown(self):
         if self.output_path.parent.exists():
-            shutil.rmtree(self.output_path.parent)
+            shutil.rmtree(self.output_path.parent)
\ No newline at end of file
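
For reference, a minimal sketch of the two entry shapes the reworked save_results emits, which the tests above assert against. The field names are taken directly from the diff; the concrete values (test name, variables, formula) are hypothetical placeholders rather than framework output.

# Illustrative sketch only: keys mirror save_results above; the values are made up.
skipped_entry = {
    "name": "example_skipped_test",      # hypothetical name from the test config
    "estimate_type": "coefficient",
    "effect": "direct",
    "treatment_variable": "test_input",
    "expected_effect": "Positive",       # shape depends on the test config
    "alpha": 0.05,
    "formula": None,                     # taken from the config, no estimator is built
    "skip": True,
    "passed": None,                      # skipped tests are never evaluated
    "result": {
        "status": "skipped",
        "reason": "Test marked as skip:true in the causal test config file.",
    },
}

executed_entry = {
    **{k: skipped_entry[k] for k in (
        "name", "estimate_type", "effect", "treatment_variable", "expected_effect", "alpha"
    )},
    "formula": "test_output ~ test_input",  # hypothetical estimator formula
    "skip": False,
    "passed": True,
    "result": {
        "treatment": "test_input",
        "outcome": "test_output",
        "adjustment_set": [],
        # ...plus the effect-estimate (and any adequacy) fields merged in, or
        # {"status": "error", "reason": ...} when estimation raised an error.
    },
}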