From 9ca361f1e5fb0de40d601df0fd058bee099a9aca Mon Sep 17 00:00:00 2001 From: Michal Kazmierski Date: Thu, 29 Jan 2026 01:38:03 -0800 Subject: [PATCH] Add strict mode support in chex.assert_trees_all_close. This aligns the behaviour with chex.assert_trees_all_equal. PiperOrigin-RevId: 862622074 --- chex/_src/asserts.py | 31 ++++++++++++++++++++++++------- chex/_src/asserts_test.py | 18 ++++++++++++++++++ 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index ce5a69e..1643d53 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -1641,9 +1641,12 @@ def _assert_trees_all_equal_jittable( ) -def _assert_trees_all_close_static(*trees: ArrayTree, - rtol: float = 1e-06, - atol: float = .0) -> None: +def _assert_trees_all_close_static( + *trees: ArrayTree, + rtol: float = 1e-06, + atol: float = 0.0, + strict: bool = False, +) -> None: """Checks that all trees have leaves with approximately equal values. This compares the difference between values of actual and desired up to @@ -1653,6 +1656,9 @@ def _assert_trees_all_close_static(*trees: ArrayTree, *trees: A sequence of (at least 2) trees with array leaves. rtol: A relative tolerance. atol: An absolute tolerance. + strict: If True, raise an AssertionError when either the shape or the data + type of the arguments does not match. The special handling for scalars + mentioned in the Notes section of `np.allclose` is disabled. Raises: AssertionError: If actual and desired values are not equal up to @@ -1664,7 +1670,9 @@ def assert_fn(arr_1, arr_2): _ai.jnp_to_np_array(arr_2), rtol=rtol, atol=atol, - err_msg="Error in value equality check: Values not approximately equal") + err_msg="Error in value equality check: Values not approximately equal", + strict=strict, + ) def cmp_fn(arr_1, arr_2) -> bool: try: @@ -1685,10 +1693,19 @@ def err_msg_fn(arr_1, arr_2) -> str: assert_trees_all_equal_comparator(cmp_fn, err_msg_fn, *trees) -def _assert_trees_all_close_jittable(*trees: ArrayTree, - rtol: float = 1e-06, - atol: float = .0) -> Array: +def _assert_trees_all_close_jittable( + *trees: ArrayTree, + rtol: float = 1e-06, + atol: float = 0.0, + strict: bool = False, +) -> Array: """A jittable version of `_assert_trees_all_close_static`.""" + if strict: + raise NotImplementedError( + "`strict=True` is not implemented by" + " `_assert_trees_all_close_jittable`." + ) + err_msg_template = ( f"Values not approximately equal ({rtol=}, {atol=}): " + "{arr_1} != {arr_2}." diff --git a/chex/_src/asserts_test.py b/chex/_src/asserts_test.py index 6d5e9da..ed41b04 100644 --- a/chex/_src/asserts_test.py +++ b/chex/_src/asserts_test.py @@ -1004,6 +1004,24 @@ def test_assert_trees_all_close_passes_values_close_but_not_equal(self): self.assertTrue( asserts._assert_trees_all_close_jittable(tree1, tree2, rtol=1e-6)) + def test_assert_trees_all_close_strict_mode(self): + # See 'notes' section of + # https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html + # for details about the 'strict' mode of `numpy.testing.assert_allclose`. + tree1 = {'a': jnp.array([1.0], dtype=jnp.float32), 'b': jnp.array(0.0)} + tree2 = {'a': jnp.array(1.0, dtype=jnp.float32), 'b': jnp.array(0.0)} + + asserts.assert_trees_all_close(tree1, tree2) + asserts.assert_trees_all_close(tree1, tree2, strict=False) + err_regex = _get_err_regex(r'Trees 0 and 1 differ in leaves \'a\'') + with self.assertRaisesRegex(AssertionError, err_regex): + asserts.assert_trees_all_close(tree1, tree2, strict=True) + + # strict=True raises NotImplementedError for jittable version + err_regex_not_impl = r'`strict=True` is not implemented' + with self.assertRaisesRegex(NotImplementedError, err_regex_not_impl): + asserts._assert_trees_all_close_jittable(tree1, tree2, strict=True) + def test_assert_trees_all_close_bfloat16(self): tree1 = {'a': jnp.asarray([0.8, 1.6], dtype=jnp.bfloat16)} tree2 = {