diff --git a/chex/_src/asserts_chexify_test.py b/chex/_src/asserts_chexify_test.py index b4226d4..3eddaed 100644 --- a/chex/_src/asserts_chexify_test.py +++ b/chex/_src/asserts_chexify_test.py @@ -217,10 +217,19 @@ def logp1_abs_safe(x): asserts_chexify.block_until_chexify_assertions_complete() err_regex = re.escape(_ai.get_chexify_err_message('assert_tree_all_finite')) - with self.assertRaisesRegex(AssertionError, f'{err_regex}.*chexify_test'): + + # On Windows, JAX traceback filtering drops the user-frame filename. + expected_regex = ( + err_regex if sys.platform == "win32" + else f"{err_regex}.*chexify_test" + ) + + with self.assertRaisesRegex(AssertionError, expected_regex): logp1_abs_safe(jnp.array([jnp.nan, 3])) # FAILS logp1_abs_safe.wait_checks() # pytype: disable=attribute-error + + def test_checkify_errors(self): @jax.jit def take_by_index_and_div(x, i, y):