Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 64 additions & 18 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,49 +322,95 @@ def assert_scalar_in(x: Any,
f"The argument must be in ({min_}, {max_}), got {x}.")


@_static_assertion
def assert_scalar_positive(x: Scalar) -> None:
"""Checks that a scalar is positive.
def _jittable_assert_scalar_positive(x: Scalar) -> Array:
"""Jittable version of assert_scalar_positive."""
pred = x > 0
checkify.check(pred, "The argument must be positive, got {}.", x)
return pred


def _jittable_assert_scalar_non_negative(x: Scalar) -> Array:
"""Jittable version of assert_scalar_non_negative."""
pred = x >= 0
checkify.check(pred, "The argument must be non-negative, was {}.", x)
return pred


def _jittable_assert_scalar_negative(x: Scalar) -> Array:
"""Jittable version of assert_scalar_negative."""
pred = x < 0
checkify.check(pred, "The argument must be negative, was {}.", x)
return pred


def _assert_scalar_positive_impl(x: Scalar) -> None:
"""Host implementation of assert_scalar_positive."""
assert_scalar(x)
if x <= 0:
raise AssertionError(f"The argument must be positive, got {x}.")


def _assert_scalar_non_negative_impl(x: Scalar) -> None:
"""Host implementation of assert_scalar_non_negative."""
assert_scalar(x)
if x < 0:
raise AssertionError(f"The argument must be non-negative, was {x}.")


def _assert_scalar_negative_impl(x: Scalar) -> None:
"""Host implementation of assert_scalar_negative."""
assert_scalar(x)
if x >= 0:
raise AssertionError(f"The argument must be negative, was {x}.")


assert_scalar_positive = _ai.chex_assertion(
assert_fn=_assert_scalar_positive_impl,
jittable_assert_fn=_jittable_assert_scalar_positive,
name='assert_scalar_positive')
assert_scalar_positive.__doc__ = """Checks that a scalar is positive.

This assertion is compatible with JAX transformations (jit, vmap, pmap)
when used inside a function wrapped with ``@chex.chexify``.

Args:
x: A value to check.

Raises:
AssertionError: If ``x`` is not a scalar or strictly positive.
"""
assert_scalar(x)
if x <= 0:
raise AssertionError(f"The argument must be positive, got {x}.")

assert_scalar_non_negative = _ai.chex_assertion(
assert_fn=_assert_scalar_non_negative_impl,
jittable_assert_fn=_jittable_assert_scalar_non_negative,
name='assert_scalar_non_negative')
assert_scalar_non_negative.__doc__ = """Checks that a scalar is non-negative.

@_static_assertion
def assert_scalar_non_negative(x: Scalar) -> None:
"""Checks that a scalar is non-negative.
This assertion is compatible with JAX transformations (jit, vmap, pmap)
when used inside a function wrapped with ``@chex.chexify``.

Args:
x: A value to check.

Raises:
AssertionError: If ``x`` is not a scalar or negative.
"""
assert_scalar(x)
if x < 0:
raise AssertionError(f"The argument must be non-negative, was {x}.")

assert_scalar_negative = _ai.chex_assertion(
assert_fn=_assert_scalar_negative_impl,
jittable_assert_fn=_jittable_assert_scalar_negative,
name='assert_scalar_negative')
assert_scalar_negative.__doc__ = """Checks that a scalar is negative.

@_static_assertion
def assert_scalar_negative(x: Scalar) -> None:
"""Checks that a scalar is negative.
This assertion is compatible with JAX transformations (jit, vmap, pmap)
when used inside a function wrapped with ``@chex.chexify``.

Args:
x: A value to check.

Raises:
AssertionError: If ``x`` is not a scalar or strictly negative.
"""
assert_scalar(x)
if x >= 0:
raise AssertionError(f"The argument must be negative, was {x}.")


@_static_assertion
Expand Down
152 changes: 152 additions & 0 deletions chex/_src/test_scalar_assertions_vmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Tests for scalar assertions with vmap compatibility (Issue #389)."""

from absl.testing import absltest
from absl.testing import parameterized
from chex._src import asserts
from chex._src import asserts_chexify
import jax
import jax.numpy as jnp


class ScalarAssertionsVmapTest(parameterized.TestCase):
"""Tests for vmap-compatible scalar assertions."""

def test_assert_scalar_positive_with_vmap_mwe(self):
"""Test the exact MWE from issue #389."""
x_scalar = 1.
x_vector = jnp.array([1., 1.])

# Test 1: Scalar (should work as before, without chexify)
asserts.assert_scalar_positive(x_scalar) # Works

# Test 2: Vector with vmap (the requested feature)
def test_vmap(x):
jax.vmap(asserts.assert_scalar_positive)(x)
return x

test_vmap_chexified = asserts_chexify.chexify(jax.jit(test_vmap), async_check=False)

# Should pass
result = test_vmap_chexified(x_vector)
self.assertTrue(jnp.array_equal(result, x_vector))

def test_assert_scalar_positive_with_vmap_failure(self):
"""Test that vmap correctly detects negative values."""
x_negative = jnp.array([1., -1., 3.])

def check_positive(x):
jax.vmap(asserts.assert_scalar_positive)(x)
return x

check_positive_chexified = asserts_chexify.chexify(jax.jit(check_positive), async_check=False)

# Should fail
with self.assertRaisesRegex(AssertionError, 'must be positive'):
check_positive_chexified(x_negative)

@parameterized.parameters(
('assert_scalar_positive', jnp.array([1., 2., 3.]), jnp.array([1., -1., 3.])),
('assert_scalar_non_negative', jnp.array([0., 1., 2.]), jnp.array([0., -1., 2.])),
('assert_scalar_negative', jnp.array([-1., -2., -3.]), jnp.array([-1., 1., -3.])),
)
def test_scalar_assertions_vmap(self, assertion_name, valid_input, invalid_input):
"""Test all scalar assertions with vmap."""
assertion_fn = getattr(asserts, assertion_name)

def check_fn(x):
jax.vmap(assertion_fn)(x)
return x

check_fn_chexified = asserts_chexify.chexify(jax.jit(check_fn), async_check=False)

# Valid input should pass
result = check_fn_chexified(valid_input)
self.assertTrue(jnp.array_equal(result, valid_input))

# Invalid input should fail
with self.assertRaises(AssertionError):
check_fn_chexified(invalid_input)

def test_backward_compatibility_without_chexify(self):
"""Test that assertions still work without chexify for scalars."""
# These should all work as before
asserts.assert_scalar_positive(1.0)
asserts.assert_scalar_positive(5)
asserts.assert_scalar_non_negative(0.0)
asserts.assert_scalar_non_negative(1)
asserts.assert_scalar_negative(-1.0)
asserts.assert_scalar_negative(-5)

# These should all fail
with self.assertRaisesRegex(AssertionError, 'must be positive'):
asserts.assert_scalar_positive(-1.0)

with self.assertRaisesRegex(AssertionError, 'must be non-negative'):
asserts.assert_scalar_non_negative(-1.0)

with self.assertRaisesRegex(AssertionError, 'must be negative'):
asserts.assert_scalar_negative(1.0)

def test_with_jit_only(self):
"""Test that assertions work with jit (without vmap)."""
def check_positive(x):
asserts.assert_scalar_positive(x)
return x * 2

check_positive_chexified = asserts_chexify.chexify(jax.jit(check_positive), async_check=False)

# Should pass
result = check_positive_chexified(5.0)
self.assertEqual(result, 10.0)

# Should fail
with self.assertRaisesRegex(AssertionError, 'must be positive'):
check_positive_chexified(-5.0)

def test_with_pmap(self):
"""Test that assertions work with pmap."""
devices = jax.local_devices()
if len(devices) < 2:
self.skipTest('Test requires at least 2 devices')

def check_positive(x):
asserts.assert_scalar_positive(x)
return x * 2

check_positive_pmapped = asserts_chexify.chexify(jax.pmap(check_positive), async_check=False)

# Should pass
x_valid = jnp.array([1.0, 2.0])
result = check_positive_pmapped(x_valid)
self.assertTrue(jnp.array_equal(result, x_valid * 2))

# Should fail
x_invalid = jnp.array([1.0, -2.0])
with self.assertRaisesRegex(AssertionError, 'must be positive'):
check_positive_pmapped(x_invalid)

def test_nested_vmap(self):
"""Test that assertions work with nested vmap."""
def check_positive(x):
# x has shape (3, 4)
jax.vmap(jax.vmap(asserts.assert_scalar_positive))(x)
return x

check_positive_chexified = asserts_chexify.chexify(jax.jit(check_positive), async_check=False)

# Should pass
x_valid = jnp.ones((3, 4))
result = check_positive_chexified(x_valid)
self.assertTrue(jnp.array_equal(result, x_valid))

# Should fail
x_invalid = jnp.array([[1., 2., 3., 4.],
[5., -6., 7., 8.],
[9., 10., 11., 12.]])
with self.assertRaisesRegex(AssertionError, 'must be positive'):
check_positive_chexified(x_invalid)


if __name__ == '__main__':
jax.config.update('jax_numpy_rank_promotion', 'raise')
absltest.main()
54 changes: 54 additions & 0 deletions chex/_src/variants_pytest_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example of using `chex.variants` with `pytest`."""

from typing import Callable
import chex
from chex._src import variants
import jax.numpy as jnp
import pytest

# `chex.variants` is primarily designed for `unittest.TestCase` and `absl.testing`.
# When using `pytest`, you can manually parametrize your tests over
# `variants.ALL_VARIANTS` (or a subset thereof) to achieve similar coverage.
#
# Note: `chex.variants` manages different JAX execution modes (JIT, PMAP, etc.)
# by decorating the function-under-test.


@pytest.mark.parametrize("variant", variants.ALL_VARIANTS)
@pytest.mark.parametrize("n", [1, 2, 3])
def test_variants_with_pytest(variant: Callable, n: int) -> None:
"""Tests a function across all Chex variants using pytest parametrization.

Args:
variant: A Chex variant decorator (e.g., with_jit, without_jit, etc.).
n: Input parameter for the test.
"""

# Define the computation you want to test.
# The `@variant` decorator will apply the specific execution mode
# (e.g., wrap in jax.jit, jax.pmap) for this test iteration.
@variant
def computation(x):
return x * x

# Execute the decorated function.
# Convert input to JAX array as some variants (like pmap) might expect it
# or handle it differently.
result = computation(jnp.array(n))

# Verify the result.
assert result == n * n