From df8bc60d713c1ca6a63d367df1c4a7fda02ff815 Mon Sep 17 00:00:00 2001 From: Sikandar Date: Sun, 11 Jan 2026 02:37:11 +0500 Subject: [PATCH 1/2] Add chex.variants pytest example --- chex/_src/variants_pytest_example.py | 54 ++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 chex/_src/variants_pytest_example.py diff --git a/chex/_src/variants_pytest_example.py b/chex/_src/variants_pytest_example.py new file mode 100644 index 0000000..cc35b65 --- /dev/null +++ b/chex/_src/variants_pytest_example.py @@ -0,0 +1,54 @@ +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Example of using `chex.variants` with `pytest`.""" + +from typing import Callable +import chex +from chex._src import variants +import jax.numpy as jnp +import pytest + +# `chex.variants` is primarily designed for `unittest.TestCase` and `absl.testing`. +# When using `pytest`, you can manually parametrize your tests over +# `variants.ALL_VARIANTS` (or a subset thereof) to achieve similar coverage. +# +# Note: `chex.variants` manages different JAX execution modes (JIT, PMAP, etc.) +# by decorating the function-under-test. + + +@pytest.mark.parametrize("variant", variants.ALL_VARIANTS) +@pytest.mark.parametrize("n", [1, 2, 3]) +def test_variants_with_pytest(variant: Callable, n: int) -> None: + """Tests a function across all Chex variants using pytest parametrization. + + Args: + variant: A Chex variant decorator (e.g., with_jit, without_jit, etc.). + n: Input parameter for the test. + """ + + # Define the computation you want to test. + # The `@variant` decorator will apply the specific execution mode + # (e.g., wrap in jax.jit, jax.pmap) for this test iteration. + @variant + def computation(x): + return x * x + + # Execute the decorated function. + # Convert input to JAX array as some variants (like pmap) might expect it + # or handle it differently. + result = computation(jnp.array(n)) + + # Verify the result. + assert result == n * n From dd2a0f9d7478a7bc613b1945dc56d61ba719dc04 Mon Sep 17 00:00:00 2001 From: Sikandar Date: Mon, 12 Jan 2026 17:27:13 +0500 Subject: [PATCH 2/2] Fix #389: Add vmap compatibility to scalar assertions - Convert assert_scalar_positive, assert_scalar_non_negative, and assert_scalar_negative from static to value assertions - Add jittable predicate functions using checkify.check - Enable usage with jax.vmap, jax.jit, and jax.pmap when wrapped with @chex.chexify - Add comprehensive test suite with 9 tests covering vmap, jit, pmap, and nested vmap - Maintain full backward compatibility for non-jitted usage - Fixes issue #389 --- chex/_src/asserts.py | 82 +++++++++--- chex/_src/test_scalar_assertions_vmap.py | 152 +++++++++++++++++++++++ 2 files changed, 216 insertions(+), 18 deletions(-) create mode 100644 chex/_src/test_scalar_assertions_vmap.py diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index ce5a69e..eb3bd8b 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -322,9 +322,56 @@ def assert_scalar_in(x: Any, f"The argument must be in ({min_}, {max_}), got {x}.") -@_static_assertion -def assert_scalar_positive(x: Scalar) -> None: - """Checks that a scalar is positive. +def _jittable_assert_scalar_positive(x: Scalar) -> Array: + """Jittable version of assert_scalar_positive.""" + pred = x > 0 + checkify.check(pred, "The argument must be positive, got {}.", x) + return pred + + +def _jittable_assert_scalar_non_negative(x: Scalar) -> Array: + """Jittable version of assert_scalar_non_negative.""" + pred = x >= 0 + checkify.check(pred, "The argument must be non-negative, was {}.", x) + return pred + + +def _jittable_assert_scalar_negative(x: Scalar) -> Array: + """Jittable version of assert_scalar_negative.""" + pred = x < 0 + checkify.check(pred, "The argument must be negative, was {}.", x) + return pred + + +def _assert_scalar_positive_impl(x: Scalar) -> None: + """Host implementation of assert_scalar_positive.""" + assert_scalar(x) + if x <= 0: + raise AssertionError(f"The argument must be positive, got {x}.") + + +def _assert_scalar_non_negative_impl(x: Scalar) -> None: + """Host implementation of assert_scalar_non_negative.""" + assert_scalar(x) + if x < 0: + raise AssertionError(f"The argument must be non-negative, was {x}.") + + +def _assert_scalar_negative_impl(x: Scalar) -> None: + """Host implementation of assert_scalar_negative.""" + assert_scalar(x) + if x >= 0: + raise AssertionError(f"The argument must be negative, was {x}.") + + +assert_scalar_positive = _ai.chex_assertion( + assert_fn=_assert_scalar_positive_impl, + jittable_assert_fn=_jittable_assert_scalar_positive, + name='assert_scalar_positive') +assert_scalar_positive.__doc__ = """Checks that a scalar is positive. + + This assertion is compatible with JAX transformations (jit, vmap, pmap) + when used inside a function wrapped with ``@chex.chexify``. Args: x: A value to check. @@ -332,14 +379,15 @@ def assert_scalar_positive(x: Scalar) -> None: Raises: AssertionError: If ``x`` is not a scalar or strictly positive. """ - assert_scalar(x) - if x <= 0: - raise AssertionError(f"The argument must be positive, got {x}.") +assert_scalar_non_negative = _ai.chex_assertion( + assert_fn=_assert_scalar_non_negative_impl, + jittable_assert_fn=_jittable_assert_scalar_non_negative, + name='assert_scalar_non_negative') +assert_scalar_non_negative.__doc__ = """Checks that a scalar is non-negative. -@_static_assertion -def assert_scalar_non_negative(x: Scalar) -> None: - """Checks that a scalar is non-negative. + This assertion is compatible with JAX transformations (jit, vmap, pmap) + when used inside a function wrapped with ``@chex.chexify``. Args: x: A value to check. @@ -347,14 +395,15 @@ def assert_scalar_non_negative(x: Scalar) -> None: Raises: AssertionError: If ``x`` is not a scalar or negative. """ - assert_scalar(x) - if x < 0: - raise AssertionError(f"The argument must be non-negative, was {x}.") +assert_scalar_negative = _ai.chex_assertion( + assert_fn=_assert_scalar_negative_impl, + jittable_assert_fn=_jittable_assert_scalar_negative, + name='assert_scalar_negative') +assert_scalar_negative.__doc__ = """Checks that a scalar is negative. -@_static_assertion -def assert_scalar_negative(x: Scalar) -> None: - """Checks that a scalar is negative. + This assertion is compatible with JAX transformations (jit, vmap, pmap) + when used inside a function wrapped with ``@chex.chexify``. Args: x: A value to check. @@ -362,9 +411,6 @@ def assert_scalar_negative(x: Scalar) -> None: Raises: AssertionError: If ``x`` is not a scalar or strictly negative. """ - assert_scalar(x) - if x >= 0: - raise AssertionError(f"The argument must be negative, was {x}.") @_static_assertion diff --git a/chex/_src/test_scalar_assertions_vmap.py b/chex/_src/test_scalar_assertions_vmap.py new file mode 100644 index 0000000..2c610c3 --- /dev/null +++ b/chex/_src/test_scalar_assertions_vmap.py @@ -0,0 +1,152 @@ +"""Tests for scalar assertions with vmap compatibility (Issue #389).""" + +from absl.testing import absltest +from absl.testing import parameterized +from chex._src import asserts +from chex._src import asserts_chexify +import jax +import jax.numpy as jnp + + +class ScalarAssertionsVmapTest(parameterized.TestCase): + """Tests for vmap-compatible scalar assertions.""" + + def test_assert_scalar_positive_with_vmap_mwe(self): + """Test the exact MWE from issue #389.""" + x_scalar = 1. + x_vector = jnp.array([1., 1.]) + + # Test 1: Scalar (should work as before, without chexify) + asserts.assert_scalar_positive(x_scalar) # Works + + # Test 2: Vector with vmap (the requested feature) + def test_vmap(x): + jax.vmap(asserts.assert_scalar_positive)(x) + return x + + test_vmap_chexified = asserts_chexify.chexify(jax.jit(test_vmap), async_check=False) + + # Should pass + result = test_vmap_chexified(x_vector) + self.assertTrue(jnp.array_equal(result, x_vector)) + + def test_assert_scalar_positive_with_vmap_failure(self): + """Test that vmap correctly detects negative values.""" + x_negative = jnp.array([1., -1., 3.]) + + def check_positive(x): + jax.vmap(asserts.assert_scalar_positive)(x) + return x + + check_positive_chexified = asserts_chexify.chexify(jax.jit(check_positive), async_check=False) + + # Should fail + with self.assertRaisesRegex(AssertionError, 'must be positive'): + check_positive_chexified(x_negative) + + @parameterized.parameters( + ('assert_scalar_positive', jnp.array([1., 2., 3.]), jnp.array([1., -1., 3.])), + ('assert_scalar_non_negative', jnp.array([0., 1., 2.]), jnp.array([0., -1., 2.])), + ('assert_scalar_negative', jnp.array([-1., -2., -3.]), jnp.array([-1., 1., -3.])), + ) + def test_scalar_assertions_vmap(self, assertion_name, valid_input, invalid_input): + """Test all scalar assertions with vmap.""" + assertion_fn = getattr(asserts, assertion_name) + + def check_fn(x): + jax.vmap(assertion_fn)(x) + return x + + check_fn_chexified = asserts_chexify.chexify(jax.jit(check_fn), async_check=False) + + # Valid input should pass + result = check_fn_chexified(valid_input) + self.assertTrue(jnp.array_equal(result, valid_input)) + + # Invalid input should fail + with self.assertRaises(AssertionError): + check_fn_chexified(invalid_input) + + def test_backward_compatibility_without_chexify(self): + """Test that assertions still work without chexify for scalars.""" + # These should all work as before + asserts.assert_scalar_positive(1.0) + asserts.assert_scalar_positive(5) + asserts.assert_scalar_non_negative(0.0) + asserts.assert_scalar_non_negative(1) + asserts.assert_scalar_negative(-1.0) + asserts.assert_scalar_negative(-5) + + # These should all fail + with self.assertRaisesRegex(AssertionError, 'must be positive'): + asserts.assert_scalar_positive(-1.0) + + with self.assertRaisesRegex(AssertionError, 'must be non-negative'): + asserts.assert_scalar_non_negative(-1.0) + + with self.assertRaisesRegex(AssertionError, 'must be negative'): + asserts.assert_scalar_negative(1.0) + + def test_with_jit_only(self): + """Test that assertions work with jit (without vmap).""" + def check_positive(x): + asserts.assert_scalar_positive(x) + return x * 2 + + check_positive_chexified = asserts_chexify.chexify(jax.jit(check_positive), async_check=False) + + # Should pass + result = check_positive_chexified(5.0) + self.assertEqual(result, 10.0) + + # Should fail + with self.assertRaisesRegex(AssertionError, 'must be positive'): + check_positive_chexified(-5.0) + + def test_with_pmap(self): + """Test that assertions work with pmap.""" + devices = jax.local_devices() + if len(devices) < 2: + self.skipTest('Test requires at least 2 devices') + + def check_positive(x): + asserts.assert_scalar_positive(x) + return x * 2 + + check_positive_pmapped = asserts_chexify.chexify(jax.pmap(check_positive), async_check=False) + + # Should pass + x_valid = jnp.array([1.0, 2.0]) + result = check_positive_pmapped(x_valid) + self.assertTrue(jnp.array_equal(result, x_valid * 2)) + + # Should fail + x_invalid = jnp.array([1.0, -2.0]) + with self.assertRaisesRegex(AssertionError, 'must be positive'): + check_positive_pmapped(x_invalid) + + def test_nested_vmap(self): + """Test that assertions work with nested vmap.""" + def check_positive(x): + # x has shape (3, 4) + jax.vmap(jax.vmap(asserts.assert_scalar_positive))(x) + return x + + check_positive_chexified = asserts_chexify.chexify(jax.jit(check_positive), async_check=False) + + # Should pass + x_valid = jnp.ones((3, 4)) + result = check_positive_chexified(x_valid) + self.assertTrue(jnp.array_equal(result, x_valid)) + + # Should fail + x_invalid = jnp.array([[1., 2., 3., 4.], + [5., -6., 7., 8.], + [9., 10., 11., 12.]]) + with self.assertRaisesRegex(AssertionError, 'must be positive'): + check_positive_chexified(x_invalid) + + +if __name__ == '__main__': + jax.config.update('jax_numpy_rank_promotion', 'raise') + absltest.main()