68 changes: 46 additions & 22 deletions causal_testing/main.py
@@ -425,7 +425,7 @@ def run_tests(

return results

def save_results(self, results: List[CausalTestResult], output_path: str = None) -> None:
def save_results(self, results: List[CausalTestResult], output_path: str = None) -> list:
"""Save test results to JSON file in the expected format."""
if output_path is None:
output_path = self.paths.output_path
@@ -438,36 +438,60 @@ def save_results(self, results: List[CausalTestResult], output_path: str = None)
with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
test_configs = json.load(f)

# Combine test configs with their results
json_results = []
for test_config, test_case, result in zip(test_configs["tests"], self.test_cases, results):
# Determine if test failed based on expected vs actual effect
test_passed = (
test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
)
result_index = 0

for test_config in test_configs["tests"]:

output = {
# Create a base output of the entries common to all tests first
base_output = {
"name": test_config["name"],
"estimate_type": test_config["estimate_type"],
"effect": test_config.get("effect", "direct"),
"treatment_variable": test_config["treatment_variable"],
"expected_effect": test_config["expected_effect"],
"formula": result.estimator.formula if hasattr(result.estimator, "formula") else None,
"alpha": test_config.get("alpha", 0.05),
"skip": test_config.get("skip", False),
"passed": test_passed,
"result": (
{
"treatment": result.estimator.base_test_case.treatment_variable.name,
"outcome": result.estimator.base_test_case.outcome_variable.name,
"adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
}
| result.effect_estimate.to_dict()
| (result.adequacy.to_dict() if result.adequacy else {})
if result.effect_estimate
else {"error": result.error_message}
),
}
if test_config.get("skip", False):
# Include the skipped test entry without execution results
output = {
**base_output,
"formula": test_config.get("formula"),
"skip": True,
"passed": None,
"result": {
"status": "skipped",
"reason": "Test marked as skip:true in the causal test config file.",
},
}
else:
# Add executed test with actual results
test_case = self.test_cases[result_index]
result = results[result_index]
result_index += 1

test_passed = (
test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
)

output = {
**base_output,
"formula": result.estimator.formula if hasattr(result.estimator, "formula") else None,
"skip": False,
"passed": test_passed,
"result": (
{
"treatment": result.estimator.base_test_case.treatment_variable.name,
"outcome": result.estimator.base_test_case.outcome_variable.name,
"adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
}
| result.effect_estimate.to_dict()
| (result.adequacy.to_dict() if result.adequacy else {})
if result.effect_estimate
else {"status": "error", "reason": result.error_message}
),
}

json_results.append(output)

# Save to file
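For orientation, a minimal sketch of the two entry shapes the reworked save_results now appends to json_results and returns (field names follow the diff above; the test names, expected effects, and formula are illustrative placeholders, not output from a real run):

# Illustrative only: the shape of a skipped entry vs. an executed entry.
skipped_entry = {
    "name": "example_skipped_test",  # hypothetical test name
    "estimate_type": "ate",
    "effect": "direct",
    "treatment_variable": "test_input",
    "expected_effect": "NoEffect",  # whatever shape the test config uses
    "alpha": 0.05,
    "formula": None,  # taken straight from the config, which may omit it
    "skip": True,
    "passed": None,
    "result": {
        "status": "skipped",
        "reason": "Test marked as skip:true in the causal test config file.",
    },
}

executed_entry = {
    "name": "example_executed_test",  # hypothetical test name
    "estimate_type": "ate",
    "effect": "direct",
    "treatment_variable": "test_input",
    "expected_effect": "Positive",  # whatever shape the test config uses
    "alpha": 0.05,
    "formula": "test_output ~ test_input",  # illustrative formula
    "skip": False,
    "passed": True,
    "result": {
        "treatment": "test_input",
        "outcome": "test_output",
        "adjustment_set": [],
        # merged with effect_estimate.to_dict() (and adequacy.to_dict() if present);
        # on estimation failure this becomes {"status": "error", "reason": <error_message>}.
    },
}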
97 changes: 49 additions & 48 deletions tests/main_tests/test_main.py
@@ -3,8 +3,6 @@
import tempfile
import os
from unittest.mock import patch


import shutil
import json
import pandas as pd
@@ -137,36 +135,45 @@ def test_unloaded_tests(self):
def test_unloaded_tests_batches(self):
framework = CausalTestingFramework(self.paths)
with self.assertRaises(ValueError) as e:
# Need the next because of the yield statement in run_tests_in_batches
next(framework.run_tests_in_batches())
self.assertEqual("No tests loaded. Call load_tests() first.", str(e.exception))

def test_ctf(self):
framework = CausalTestingFramework(self.paths)
framework.setup()

# Load and run tests
framework.load_tests()
results = framework.run_tests()

# Save results
framework.save_results(results)
json_results = framework.save_results(results)

with open(self.test_config_path, "r", encoding="utf-8") as f:
test_configs = json.load(f)

tests_passed = [
test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
for test_config, test_case, result in zip(test_configs["tests"], framework.test_cases, results)
]
self.assertEqual(len(json_results), len(test_configs["tests"]))

self.assertEqual(tests_passed, [True])
result_index = 0
for i, test_config in enumerate(test_configs["tests"]):
result = json_results[i]

if test_config.get("skip", False):
self.assertEqual(result["skip"], True)
self.assertEqual(result["passed"], None)
self.assertEqual(result["result"]["status"], "skipped")
else:
test_case = framework.test_cases[result_index]
framework_result = results[result_index]
result_index += 1

test_passed = (
test_case.expected_causal_effect.apply(framework_result)
if framework_result.effect_estimate is not None else False
)
self.assertEqual(result["passed"], test_passed)

def test_ctf_batches(self):
framework = CausalTestingFramework(self.paths)
framework.setup()

# Load and run tests
framework.load_tests()

output_files = []
@@ -177,19 +184,18 @@ def test_ctf_batches(self):
output_files.append(temp_file_path)
del results

# Now stitch the results together from the temporary files
all_results = []
for file_path in output_files:
with open(file_path, "r", encoding="utf-8") as f:
all_results.extend(json.load(f))

self.assertEqual([result["passed"] for result in all_results], [True])
executed_results = [result for result in all_results if not result.get("skip", False)]
self.assertEqual([result["passed"] for result in executed_results], [True])

def test_ctf_exception(self):
framework = CausalTestingFramework(self.paths, query="test_input < 0")
framework.setup()

# Load and run tests
framework.load_tests()
with self.assertRaises(ValueError):
framework.run_tests()
@@ -198,7 +204,6 @@ def test_ctf_batches_exception_silent(self):
framework = CausalTestingFramework(self.paths, query="test_input < 0")
framework.setup()

# Load and run tests
framework.load_tests()

output_files = []
@@ -209,55 +214,48 @@ def test_ctf_batches_exception_silent(self):
output_files.append(temp_file_path)
del results

# Now stitch the results together from the temporary files
all_results = []
for file_path in output_files:
with open(file_path, "r", encoding="utf-8") as f:
all_results.extend(json.load(f))

self.assertEqual([result["passed"] for result in all_results], [False])
self.assertIsNotNone([result.get("error") for result in all_results])
executed_results = [result for result in all_results if not result.get("skip", False)]
self.assertEqual([result["passed"] for result in executed_results], [False])
self.assertIsNotNone([result.get("error") for result in executed_results])

def test_ctf_exception_silent(self):
framework = CausalTestingFramework(self.paths, query="test_input < 0")
framework.setup()

# Load and run tests
framework.load_tests()

results = framework.run_tests(silent=True)
json_results = framework.save_results(results)

with open(self.test_config_path, "r", encoding="utf-8") as f:
test_configs = json.load(f)

tests_passed = [
test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
for test_config, test_case, result in zip(test_configs["tests"], framework.test_cases, results)
]
non_skipped_configs = [t for t in test_configs["tests"] if not t.get("skip", False)]
non_skipped_results = [r for r in json_results if not r.get("skip", False)]

self.assertEqual(tests_passed, [False])
self.assertEqual(
[result.error_message for result in results],
["zero-size array to reduction operation maximum which has no identity"],
)
self.assertEqual(len(non_skipped_results), len(non_skipped_configs))

for result in non_skipped_results:
self.assertEqual(result["passed"], False)

def test_ctf_batches_exception(self):
framework = CausalTestingFramework(self.paths, query="test_input < 0")
framework.setup()

# Load and run tests
framework.load_tests()
with self.assertRaises(ValueError):
next(framework.run_tests_in_batches())

def test_ctf_batches_matches_run_tests(self):
# Run the tests normally
framework = CausalTestingFramework(self.paths)
framework.setup()
framework.load_tests()
normale_results = framework.run_tests()
normal_results = framework.run_tests()

# Run the tests in batches
output_files = []
with tempfile.TemporaryDirectory() as tmpdir:
for i, results in enumerate(framework.run_tests_in_batches()):
@@ -266,24 +264,24 @@ def test_ctf_batches_matches_run_tests(self):
output_files.append(temp_file_path)
del results

# Now stitch the results together from the temporary files
all_results = []
for file_path in output_files:
with open(file_path, "r", encoding="utf-8") as f:
all_results.extend(json.load(f))

with tempfile.TemporaryDirectory() as tmpdir:
normal_output = os.path.join(tmpdir, f"normal.json")
framework.save_results(normale_results, normal_output)
normal_output = os.path.join(tmpdir, "normal.json")
framework.save_results(normal_results, normal_output)
with open(normal_output) as f:
normal_results = json.load(f)
normal_json = json.load(f)

batch_output = os.path.join(tmpdir, f"batch.json")
batch_output = os.path.join(tmpdir, "batch.json")
with open(batch_output, "w") as f:
json.dump(all_results, f)
with open(batch_output) as f:
batch_results = json.load(f)
self.assertEqual(normal_results, batch_results)
batch_json = json.load(f)

self.assertEqual(normal_json, batch_json)

def test_global_query(self):
framework = CausalTestingFramework(self.paths)
@@ -308,7 +306,6 @@ def test_global_query(self):
self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())

query_framework.create_variables()

self.assertIsNotNone(query_framework.scenario)

def test_test_specific_query(self):
@@ -383,7 +380,8 @@ def test_parse_args_adequacy(self):
main()
with open(self.output_path.parent / "main.json") as f:
log = json.load(f)
assert all(test["result"]["bootstrap_size"] == 100 for test in log)
executed_tests = [test for test in log if not test.get("skip", False)]
assert all(test["result"].get("bootstrap_size", 100) == 100 for test in executed_tests)

def test_parse_args_adequacy_batches(self):
with patch(
@@ -407,7 +405,8 @@ def test_parse_args_adequacy_batches(self):
main()
with open(self.output_path.parent / "main.json") as f:
log = json.load(f)
assert all(test["result"]["bootstrap_size"] == 100 for test in log)
executed_tests = [test for test in log if not test.get("skip", False)]
assert all(test["result"].get("bootstrap_size", 100) == 100 for test in executed_tests)

def test_parse_args_bootstrap_size(self):
with patch(
@@ -430,7 +429,8 @@ def test_parse_args_bootstrap_size(self):
main()
with open(self.output_path.parent / "main.json") as f:
log = json.load(f)
assert all(test["result"]["bootstrap_size"] == 50 for test in log)
executed_tests = [test for test in log if not test.get("skip", False)]
assert all(test["result"].get("bootstrap_size", 50) == 50 for test in executed_tests)

def test_parse_args_bootstrap_size_explicit_adequacy(self):
with patch(
@@ -454,7 +454,8 @@ def test_parse_args_bootstrap_size_explicit_adequacy(self):
main()
with open(self.output_path.parent / "main.json") as f:
log = json.load(f)
assert all(test["result"]["bootstrap_size"] == 50 for test in log)
executed_tests = [test for test in log if not test.get("skip", False)]
assert all(test["result"].get("bootstrap_size", 50) == 50 for test in executed_tests)

def test_parse_args_batches(self):
with patch(
@@ -517,4 +518,4 @@ def test_parse_args_generation_non_default(self):

def tearDown(self):
if self.output_path.parent.exists():
shutil.rmtree(self.output_path.parent)
shutil.rmtree(self.output_path.parent)
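The skip handling above assumes tests can be flagged in the causal test config file; the following is a hypothetical sketch of such a config entry, inferred only from the keys that save_results reads (the real schema may carry additional fields):

import json

# Hypothetical content for the file at paths.test_config_path.
causal_test_config = {
    "tests": [
        {
            "name": "effect_of_test_input_on_test_output",  # hypothetical name
            "estimate_type": "ate",
            "effect": "direct",
            "treatment_variable": "test_input",
            "expected_effect": "Positive",  # shape depends on the framework's config schema
            "alpha": 0.05,
            "skip": True,  # reported in the output JSON but never executed
        }
    ]
}

# Write it where the framework expects the test config (path is illustrative).
with open("causal_tests.json", "w", encoding="utf-8") as f:
    json.dump(causal_test_config, f, indent=2)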