diff --git a/chex/__init__.py b/chex/__init__.py index 0da55ec..c0ce4c0 100644 --- a/chex/__init__.py +++ b/chex/__init__.py @@ -33,6 +33,7 @@ from chex._src.asserts import assert_is_divisible from chex._src.asserts import assert_max_traces from chex._src.asserts import assert_not_both_none +from chex._src.asserts import assert_not_both_not_none from chex._src.asserts import assert_numerical_grads from chex._src.asserts import assert_rank from chex._src.asserts import assert_scalar @@ -138,6 +139,7 @@ "assert_is_divisible", "assert_max_traces", "assert_not_both_none", + "assert_not_both_not_none", "assert_numerical_grads", "assert_rank", "assert_scalar", diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 6c08c21..6230d3f 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -249,6 +249,22 @@ def assert_not_both_none(first: Any, second: Any) -> None: "At least one of the arguments must be different from `None`.") +@_static_assertion +def assert_not_both_not_none(first: Any, second: Any) -> None: + """Checks that at most one of the arguments is not `None`. + + Args: + first: A first object. + second: A second object. + + Raises: + AssertionError: If ``(first is not None) and (second is not None)``. + """ + if first is not None and second is not None: + raise AssertionError( + "At most one of the arguments must be different from `None`.") + + @_static_assertion def assert_exactly_one_is_none(first: Any, second: Any) -> None: """Checks that one and only one of the arguments is `None`. diff --git a/chex/_src/asserts_test.py b/chex/_src/asserts_test.py index 6d5e9da..1b33562 100644 --- a/chex/_src/asserts_test.py +++ b/chex/_src/asserts_test.py @@ -1871,6 +1871,28 @@ def test_assert_is_divisible_fail(self): asserts.assert_is_divisible(7, 3) + +class NoneCheckTest(parameterized.TestCase): + + def test_assert_not_both_not_none(self): + asserts.assert_not_both_not_none(None, None) + asserts.assert_not_both_not_none(1, None) + asserts.assert_not_both_not_none(None, 1) + + with self.assertRaisesRegex(AssertionError, + _get_err_regex('At most one')): + asserts.assert_not_both_not_none(1, 1) + + def test_assert_not_both_none(self): + asserts.assert_not_both_none(1, None) + asserts.assert_not_both_none(None, 1) + asserts.assert_not_both_none(1, 1) + + with self.assertRaisesRegex(AssertionError, + _get_err_regex('At least one')): + asserts.assert_not_both_none(None, None) + + if __name__ == '__main__': jax.config.update('jax_numpy_rank_promotion', 'raise') absltest.main()