diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 6c08c21..e950bf1 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -216,9 +216,30 @@ def assert_gpu_available(backend: Optional[str] = None) -> None: raise AssertionError(f"No GPU devices available in {jax.devices(backend)}.") -@_static_assertion -def assert_equal(first: Any, second: Any) -> None: - """Checks that the two objects are equal as determined by the `==` operator. +def _assert_equal_static(first: Any, second: Any) -> None: + """Static (host) implementation of assert_equal.""" + unittest.TestCase().assertEqual(first, second) + + +def _assert_equal_jittable(first: Any, second: Any) -> Array: + """Jittable implementation of assert_equal for use with chexify.""" + first_arr = jnp.asarray(first) + second_arr = jnp.asarray(second) + are_equal = jnp.array_equal(first_arr, second_arr) + checkify.check(are_equal, "Values are not equal: {first} != {second}.", + first=first_arr, second=second_arr) + return are_equal + + +assert_equal = _value_assertion( + assert_fn=_assert_equal_static, + jittable_assert_fn=_assert_equal_jittable, + name="assert_equal", +) +assert_equal.__doc__ = """Checks that the two objects are equal. + + When used inside a ``@chex.chexify`` decorated function, this assertion + supports JAX tracing and can compare traced values. Arrays with more than one element cannot be compared. Use ``assert_trees_all_close`` to compare arrays. @@ -230,7 +251,6 @@ def assert_equal(first: Any, second: Any) -> None: Raises: AssertionError: If not ``(first == second)``. """ - unittest.TestCase().assertEqual(first, second) @_static_assertion @@ -248,6 +268,20 @@ def assert_not_both_none(first: Any, second: Any) -> None: raise AssertionError( "At least one of the arguments must be different from `None`.") +@_static_assertion +def assert_not_both_not_none(first: Any, second: Any) -> None: + """Checks that not both arguments are non-None. + + Args: + first: A first object. + second: A second object. + + Raises: + AssertionError: If both ``first`` and ``second`` are not None. + """ + if first is not None and second is not None: + raise AssertionError( + "At most one of the arguments may be different from `None`.") @_static_assertion def assert_exactly_one_is_none(first: Any, second: Any) -> None: diff --git a/chex/_src/asserts_test.py b/chex/_src/asserts_test.py index 6d5e9da..622dcc3 100644 --- a/chex/_src/asserts_test.py +++ b/chex/_src/asserts_test.py @@ -1860,7 +1860,15 @@ def test_assert_equal_fail(self, first, second): with self.assertRaises(AssertionError): asserts.assert_equal(first, second) +class NotNoneAssertionsTest(parameterized.TestCase): + def test_assert_not_both_not_none(self): + asserts.assert_not_both_not_none(None, None) + asserts.assert_not_both_not_none(1, None) + asserts.assert_not_both_not_none(None, 1) + + with self.assertRaises(AssertionError): + asserts.assert_not_both_not_none(1, 2) class IsDivisibleTest(parameterized.TestCase): def test_assert_is_divisible(self):