Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,10 +1060,15 @@ def _assert_fn(path, leaf):
# This is for backwards compatibility.
def _check_sharding(x):
if hasattr(jax, "Array") and isinstance(x, jax.Array):
if not jax.typeof(x).sharding.is_fully_replicated:
return True
# Use x.sharding directly for concrete arrays.
sharding = getattr(x, "sharding", None)
if sharding is not None:
return not sharding.is_fully_replicated or len(sharding.device_set) > 1
else:
return len(x.sharding.device_set) > 1
# Fallback for backward compatibility.
# Many older JAX versions or specific tracers might still need jax.typeof.
sharding = jax.typeof(x).sharding
return not sharding.is_fully_replicated or len(sharding.device_set) > 1
# pytype: disable=attribute-error
return (
hasattr(jax, "pxla")
Expand Down Expand Up @@ -1130,16 +1135,16 @@ def _assert_fn(path, leaf):
f" on {leaf_device}."
)
else:
errors.append((
errors.append(
f"Tree leaf '{_ai.format_tree_path(path)}' resides "
f"on {leaf.devices()} (CPU devices are disallowed)."
))
)
else:
# Not a jax.Array.
errors.append((
errors.append(
f"Tree leaf '{_ai.format_tree_path(path)}' has "
f"unexpected type: {type(leaf)}."
))
)

for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
_assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf)
Expand Down Expand Up @@ -1183,27 +1188,31 @@ def _assert_fn(path, leaf):
# Check that the leaf is a DeviceArray.
if isinstance(leaf, jax.Array):
if _check_sharding(leaf):
errors.append((f"Tree leaf '{_ai.format_tree_path(path)}' is a "
f"ShardedDeviceArray which are disallowed. "
f" (type={type(leaf)})."))
errors.append(
f"Tree leaf '{_ai.format_tree_path(path)}' is a "
f"sharded JAX array (historically ShardedDeviceArray) "
f"which are disallowed. (type={type(leaf)})."
)
else: # DeviceArray and not ShardedDeviceArray
# Check the platform.
leaf_device = list(leaf.devices())[0]
if leaf_device.platform not in platform:
errors.append((
errors.append(
f"Tree leaf '{_ai.format_tree_path(path)}' resides on "
f"'{leaf_device.platform}', expected '{platform}'."
))
)

# Check the device.
if device is not None and leaf.devices() != {device}:
errors.append((
errors.append(
f"Tree leaf '{_ai.format_tree_path(path)}' resides on "
f"{leaf.devices()}, expected {device}."
))
)
else:
errors.append((f"Tree leaf '{_ai.format_tree_path(path)}' has "
f"unexpected type: {type(leaf)}."))
errors.append(
f"Tree leaf '{_ai.format_tree_path(path)}' has "
f"unexpected type: {type(leaf)}."
)

for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
_assert_fn(_ai.convert_jax_path_to_dm_path(path), leaf)
Expand Down
Loading