From 25107d3a11ab294c8d6912e6fc6514f538fba5eb Mon Sep 17 00:00:00 2001 From: ashutosh0x Date: Mon, 5 Jan 2026 03:55:20 +0530 Subject: [PATCH 1/2] Fix: Modernize sharding detection and error messages for JAX 0.8.2 compatibility (#422) --- chex/_src/asserts.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 6c08c21..461d3f4 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -1060,10 +1060,13 @@ def _assert_fn(path, leaf): # This is for backwards compatibility. def _check_sharding(x): if hasattr(jax, "Array") and isinstance(x, jax.Array): - if not jax.typeof(x).sharding.is_fully_replicated: - return True - else: - return len(x.sharding.device_set) > 1 + # Use x.sharding directly for concrete arrays. + sharding = getattr(x, 'sharding', None) + if sharding is not None: + if not sharding.is_fully_replicated: + return True + else: + return len(sharding.device_set) > 1 # pytype: disable=attribute-error return ( hasattr(jax, "pxla") @@ -1184,8 +1187,8 @@ def _assert_fn(path, leaf): if isinstance(leaf, jax.Array): if _check_sharding(leaf): errors.append((f"Tree leaf '{_ai.format_tree_path(path)}' is a " - f"ShardedDeviceArray which are disallowed. " - f" (type={type(leaf)}).")) + f"sharded JAX array (historically ShardedDeviceArray) " + f"which are disallowed. (type={type(leaf)}).")) else: # DeviceArray and not ShardedDeviceArray # Check the platform. leaf_device = list(leaf.devices())[0] From ba1e5fcd6f45934527048fdd311c167539799536 Mon Sep 17 00:00:00 2001 From: Ashutosh0x Date: Tue, 6 Jan 2026 19:41:18 +0530 Subject: [PATCH 2/2] fix: address review comments (compatibility and style adjustments) --- chex/_src/asserts.py | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 461d3f4..99d6eff 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -1061,12 +1061,14 @@ def _assert_fn(path, leaf): def _check_sharding(x): if hasattr(jax, "Array") and isinstance(x, jax.Array): # Use x.sharding directly for concrete arrays. - sharding = getattr(x, 'sharding', None) + sharding = getattr(x, "sharding", None) if sharding is not None: - if not sharding.is_fully_replicated: - return True - else: - return len(sharding.device_set) > 1 + return not sharding.is_fully_replicated or len(sharding.device_set) > 1 + else: + # Fallback for backward compatibility. + # Many older JAX versions or specific tracers might still need jax.typeof. + sharding = jax.typeof(x).sharding + return not sharding.is_fully_replicated or len(sharding.device_set) > 1 # pytype: disable=attribute-error return ( hasattr(jax, "pxla") @@ -1133,16 +1135,16 @@ def _assert_fn(path, leaf): f" on {leaf_device}." ) else: - errors.append(( + errors.append( f"Tree leaf '{_ai.format_tree_path(path)}' resides " f"on {leaf.devices()} (CPU devices are disallowed)." - )) + ) else: # Not a jax.Array. - errors.append(( + errors.append( f"Tree leaf '{_ai.format_tree_path(path)}' has " f"unexpected type: {type(leaf)}." - )) + ) for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]: _assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf) @@ -1186,27 +1188,31 @@ def _assert_fn(path, leaf): # Check that the leaf is a DeviceArray. if isinstance(leaf, jax.Array): if _check_sharding(leaf): - errors.append((f"Tree leaf '{_ai.format_tree_path(path)}' is a " - f"sharded JAX array (historically ShardedDeviceArray) " - f"which are disallowed. (type={type(leaf)}).")) + errors.append( + f"Tree leaf '{_ai.format_tree_path(path)}' is a " + f"sharded JAX array (historically ShardedDeviceArray) " + f"which are disallowed. (type={type(leaf)})." + ) else: # DeviceArray and not ShardedDeviceArray # Check the platform. leaf_device = list(leaf.devices())[0] if leaf_device.platform not in platform: - errors.append(( + errors.append( f"Tree leaf '{_ai.format_tree_path(path)}' resides on " f"'{leaf_device.platform}', expected '{platform}'." - )) + ) # Check the device. if device is not None and leaf.devices() != {device}: - errors.append(( + errors.append( f"Tree leaf '{_ai.format_tree_path(path)}' resides on " f"{leaf.devices()}, expected {device}." - )) + ) else: - errors.append((f"Tree leaf '{_ai.format_tree_path(path)}' has " - f"unexpected type: {type(leaf)}.")) + errors.append( + f"Tree leaf '{_ai.format_tree_path(path)}' has " + f"unexpected type: {type(leaf)}." + ) for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]: _assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf)