Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,9 +1641,12 @@ def _assert_trees_all_equal_jittable(
)


def _assert_trees_all_close_static(*trees: ArrayTree,
rtol: float = 1e-06,
atol: float = .0) -> None:
def _assert_trees_all_close_static(
*trees: ArrayTree,
rtol: float = 1e-06,
atol: float = 0.0,
strict: bool = False,
) -> None:
"""Checks that all trees have leaves with approximately equal values.

This compares the difference between values of actual and desired up to
Expand All @@ -1653,6 +1656,9 @@ def _assert_trees_all_close_static(*trees: ArrayTree,
*trees: A sequence of (at least 2) trees with array leaves.
rtol: A relative tolerance.
atol: An absolute tolerance.
strict: If True, raise an AssertionError when either the shape or the data
type of the arguments does not match. The special handling for scalars
mentioned in the Notes section of `np.allclose` is disabled.

Raises:
AssertionError: If actual and desired values are not equal up to
Expand All @@ -1664,7 +1670,9 @@ def assert_fn(arr_1, arr_2):
_ai.jnp_to_np_array(arr_2),
rtol=rtol,
atol=atol,
err_msg="Error in value equality check: Values not approximately equal")
err_msg="Error in value equality check: Values not approximately equal",
strict=strict,
)

def cmp_fn(arr_1, arr_2) -> bool:
try:
Expand All @@ -1685,10 +1693,19 @@ def err_msg_fn(arr_1, arr_2) -> str:
assert_trees_all_equal_comparator(cmp_fn, err_msg_fn, *trees)


def _assert_trees_all_close_jittable(*trees: ArrayTree,
rtol: float = 1e-06,
atol: float = .0) -> Array:
def _assert_trees_all_close_jittable(
*trees: ArrayTree,
rtol: float = 1e-06,
atol: float = 0.0,
strict: bool = False,
) -> Array:
"""A jittable version of `_assert_trees_all_close_static`."""
if strict:
raise NotImplementedError(
"`strict=True` is not implemented by"
" `_assert_trees_all_close_jittable`."
)

err_msg_template = (
f"Values not approximately equal ({rtol=}, {atol=}): "
+ "{arr_1} != {arr_2}."
Expand Down
18 changes: 18 additions & 0 deletions chex/_src/asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,24 @@ def test_assert_trees_all_close_passes_values_close_but_not_equal(self):
self.assertTrue(
asserts._assert_trees_all_close_jittable(tree1, tree2, rtol=1e-6))

def test_assert_trees_all_close_strict_mode(self):
# See 'notes' section of
# https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html
# for details about the 'strict' mode of `numpy.testing.assert_allclose`.
tree1 = {'a': jnp.array([1.0], dtype=jnp.float32), 'b': jnp.array(0.0)}
tree2 = {'a': jnp.array(1.0, dtype=jnp.float32), 'b': jnp.array(0.0)}

asserts.assert_trees_all_close(tree1, tree2)
asserts.assert_trees_all_close(tree1, tree2, strict=False)
err_regex = _get_err_regex(r'Trees 0 and 1 differ in leaves \'a\'')
with self.assertRaisesRegex(AssertionError, err_regex):
asserts.assert_trees_all_close(tree1, tree2, strict=True)

# strict=True raises NotImplementedError for jittable version
err_regex_not_impl = r'`strict=True` is not implemented'
with self.assertRaisesRegex(NotImplementedError, err_regex_not_impl):
asserts._assert_trees_all_close_jittable(tree1, tree2, strict=True)

def test_assert_trees_all_close_bfloat16(self):
tree1 = {'a': jnp.asarray([0.8, 1.6], dtype=jnp.bfloat16)}
tree2 = {
Expand Down
Loading