From 999c1d0f67eef3996eb344e029ae3f3edc14b229 Mon Sep 17 00:00:00 2001 From: pandaayushman Date: Sun, 21 Dec 2025 13:05:26 +0530 Subject: [PATCH] Make chexify docstring test robust to platform-specific tracebacks --- chex/_src/asserts_chexify_test.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/chex/_src/asserts_chexify_test.py b/chex/_src/asserts_chexify_test.py index b4226d49..3eddaed5 100644 --- a/chex/_src/asserts_chexify_test.py +++ b/chex/_src/asserts_chexify_test.py @@ -217,10 +217,19 @@ def logp1_abs_safe(x): asserts_chexify.block_until_chexify_assertions_complete() err_regex = re.escape(_ai.get_chexify_err_message('assert_tree_all_finite')) - with self.assertRaisesRegex(AssertionError, f'{err_regex}.*chexify_test'): + + # On Windows, JAX traceback filtering drops the user-frame filename. + expected_regex = ( + err_regex if sys.platform == "win32" + else f"{err_regex}.*chexify_test" + ) + + with self.assertRaisesRegex(AssertionError, expected_regex): logp1_abs_safe(jnp.array([jnp.nan, 3])) # FAILS logp1_abs_safe.wait_checks() # pytype: disable=attribute-error + + def test_checkify_errors(self): @jax.jit def take_by_index_and_div(x, i, y):