diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index 6c08c21..99d6eff 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -1060,10 +1060,15 @@ def _assert_fn(path, leaf): # This is for backwards compatibility. def _check_sharding(x): if hasattr(jax, "Array") and isinstance(x, jax.Array): - if not jax.typeof(x).sharding.is_fully_replicated: - return True + # Use x.sharding directly for concrete arrays. + sharding = getattr(x, "sharding", None) + if sharding is not None: + return not sharding.is_fully_replicated or len(sharding.device_set) > 1 else: - return len(x.sharding.device_set) > 1 + # Fallback for backward compatibility. + # Many older JAX versions or specific tracers might still need jax.typeof. + sharding = jax.typeof(x).sharding + return not sharding.is_fully_replicated or len(sharding.device_set) > 1 # pytype: disable=attribute-error return ( hasattr(jax, "pxla") @@ -1130,16 +1135,16 @@ def _assert_fn(path, leaf): f" on {leaf_device}." ) else: - errors.append(( + errors.append( f"Tree leaf '{_ai.format_tree_path(path)}' resides " f"on {leaf.devices()} (CPU devices are disallowed)." - )) + ) else: # Not a jax.Array. - errors.append(( + errors.append( f"Tree leaf '{_ai.format_tree_path(path)}' has " f"unexpected type: {type(leaf)}." - )) + ) for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]: _assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf) @@ -1183,27 +1188,31 @@ def _assert_fn(path, leaf): # Check that the leaf is a DeviceArray. if isinstance(leaf, jax.Array): if _check_sharding(leaf): - errors.append((f"Tree leaf '{_ai.format_tree_path(path)}' is a " - f"ShardedDeviceArray which are disallowed. " - f" (type={type(leaf)}).")) + errors.append( + f"Tree leaf '{_ai.format_tree_path(path)}' is a " + f"sharded JAX array (historically ShardedDeviceArray) " + f"which are disallowed. (type={type(leaf)})." + ) else: # DeviceArray and not ShardedDeviceArray # Check the platform. leaf_device = list(leaf.devices())[0] if leaf_device.platform not in platform: - errors.append(( + errors.append( f"Tree leaf '{_ai.format_tree_path(path)}' resides on " f"'{leaf_device.platform}', expected '{platform}'." - )) + ) # Check the device. if device is not None and leaf.devices() != {device}: - errors.append(( + errors.append( f"Tree leaf '{_ai.format_tree_path(path)}' resides on " f"{leaf.devices()}, expected {device}." - )) + ) else: - errors.append((f"Tree leaf '{_ai.format_tree_path(path)}' has " - f"unexpected type: {type(leaf)}.")) + errors.append( + f"Tree leaf '{_ai.format_tree_path(path)}' has " + f"unexpected type: {type(leaf)}." + ) for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]: _assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf)