From 01557900b2d41046c360c80e68cb8469d9598f94 Mon Sep 17 00:00:00 2001 From: sorata-kanda Date: Mon, 22 Dec 2025 21:37:11 +0530 Subject: [PATCH 1/2] Add assert_not_both_not_none (fixes #393) --- chex/_src/asserts.py | 14 ++++++++++++++ chex/_src/asserts_test.py | 8 ++++++++ 2 files changed, 22 insertions(+) diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 6c08c21..4c0696d 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -248,6 +248,20 @@ def assert_not_both_none(first: Any, second: Any) -> None: raise AssertionError( "At least one of the arguments must be different from `None`.") +@_static_assertion +def assert_not_both_not_none(first: Any, second: Any) -> None: + """Checks that not both arguments are non-None. + + Args: + first: A first object. + second: A second object. + + Raises: + AssertionError: If both ``first`` and ``second`` are not None. + """ + if first is not None and second is not None: + raise AssertionError( + "At most one of the arguments may be different from `None`.") @_static_assertion def assert_exactly_one_is_none(first: Any, second: Any) -> None: diff --git a/chex/_src/asserts_test.py b/chex/_src/asserts_test.py index 6d5e9da..622dcc3 100644 --- a/chex/_src/asserts_test.py +++ b/chex/_src/asserts_test.py @@ -1860,7 +1860,15 @@ def test_assert_equal_fail(self, first, second): with self.assertRaises(AssertionError): asserts.assert_equal(first, second) +class NotNoneAssertionsTest(parameterized.TestCase): + def test_assert_not_both_not_none(self): + asserts.assert_not_both_not_none(None, None) + asserts.assert_not_both_not_none(1, None) + asserts.assert_not_both_not_none(None, 1) + + with self.assertRaises(AssertionError): + asserts.assert_not_both_not_none(1, 2) class IsDivisibleTest(parameterized.TestCase): def test_assert_is_divisible(self): From 0872c8a9dab1ef4075224ad473adbf0a8fd2137b Mon Sep 17 00:00:00 2001 From: sorata-kanda Date: Tue, 13 Jan 2026 15:23:47 +0530 Subject: [PATCH 2/2] Fix assert_equal to work with @chex.chexify and @jax.jit Fixes #387 Previously, assert_equal was a static-only assertion that failed with TracerBoolConversionError when used inside @chex.chexify + @jax.jit decorated functions, because it tried to compare traced values using unittest.TestCase().assertEqual(). This change adds a jittable implementation using jnp.array_equal() and registers assert_equal as a value assertion, following the same pattern as assert_trees_all_equal. --- chex/_src/asserts.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 4c0696d..e950bf1 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -216,9 +216,30 @@ def assert_gpu_available(backend: Optional[str] = None) -> None: raise AssertionError(f"No GPU devices available in {jax.devices(backend)}.") -@_static_assertion -def assert_equal(first: Any, second: Any) -> None: - """Checks that the two objects are equal as determined by the `==` operator. +def _assert_equal_static(first: Any, second: Any) -> None: + """Static (host) implementation of assert_equal.""" + unittest.TestCase().assertEqual(first, second) + + +def _assert_equal_jittable(first: Any, second: Any) -> Array: + """Jittable implementation of assert_equal for use with chexify.""" + first_arr = jnp.asarray(first) + second_arr = jnp.asarray(second) + are_equal = jnp.array_equal(first_arr, second_arr) + checkify.check(are_equal, "Values are not equal: {first} != {second}.", + first=first_arr, second=second_arr) + return are_equal + + +assert_equal = _value_assertion( + assert_fn=_assert_equal_static, + jittable_assert_fn=_assert_equal_jittable, + name="assert_equal", +) +assert_equal.__doc__ = """Checks that the two objects are equal. + + When used inside a ``@chex.chexify`` decorated function, this assertion + supports JAX tracing and can compare traced values. Arrays with more than one element cannot be compared. Use ``assert_trees_all_close`` to compare arrays. @@ -230,7 +251,6 @@ def assert_equal(first: Any, second: Any) -> None: Raises: AssertionError: If not ``(first == second)``. """ - unittest.TestCase().assertEqual(first, second) @_static_assertion