From 92737814c29beafb3943858e2f8d73612ab4b460 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Wed, 23 Apr 2025 20:33:43 -0700 Subject: [PATCH] [JAX] Replace GSPMDSharding with NamedSharding. This is to prepare for turning on jax_use_shardy_partitioner by default. PiperOrigin-RevId: 750827303 --- vmoe/checkpoints/partitioned.py | 6 ++++-- vmoe/partitioning.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vmoe/checkpoints/partitioned.py b/vmoe/checkpoints/partitioned.py index 99cf06d..cf04b85 100644 --- a/vmoe/checkpoints/partitioned.py +++ b/vmoe/checkpoints/partitioned.py @@ -56,8 +56,10 @@ IndexInfo = types.IndexInfo LazyArrayChunks = types.LazyArrayChunks MapResult = multiprocessing.pool.MapResult -GSPMDSharding = jax.sharding.GSPMDSharding PyTree = Any +Mesh = jax.sharding.Mesh +NamedSharding = jax.sharding.NamedSharding +PartitionSpec = jax.sharding.PartitionSpec Sharding = jax.sharding.Sharding Slice = types.Slice SliceNd = types.SliceNd @@ -361,7 +363,7 @@ def _get_array_sharding_or_default(arr: jax.Array) -> Sharding: if hasattr(arr, 'sharding'): return arr.sharding else: - return GSPMDSharding.get_replicated(jax.devices()) + return NamedSharding(Mesh(jax.devices(), 'x'), PartitionSpec()) def _intersect_slicend(a: SliceNd, b: SliceNd) -> Optional[SliceNd]: diff --git a/vmoe/partitioning.py b/vmoe/partitioning.py index 01c15d1..3783135 100644 --- a/vmoe/partitioning.py +++ b/vmoe/partitioning.py @@ -87,7 +87,7 @@ def get_array_sharding_or_default(arr: jax.Array) -> jax.sharding.Sharding: if hasattr(arr, 'sharding'): return arr.sharding else: - return jax.sharding.GSPMDSharding.get_replicated(jax.devices()) + return NamedSharding(Mesh(jax.devices(), 'x'), PartitionSpec()) def process_has_contiguous_device_slice(devices: np.ndarray,