diff --git a/vmoe/checkpoints/partitioned.py b/vmoe/checkpoints/partitioned.py index 99cf06d..cf04b85 100644 --- a/vmoe/checkpoints/partitioned.py +++ b/vmoe/checkpoints/partitioned.py @@ -56,8 +56,10 @@ IndexInfo = types.IndexInfo LazyArrayChunks = types.LazyArrayChunks MapResult = multiprocessing.pool.MapResult -GSPMDSharding = jax.sharding.GSPMDSharding PyTree = Any +Mesh = jax.sharding.Mesh +NamedSharding = jax.sharding.NamedSharding +PartitionSpec = jax.sharding.PartitionSpec Sharding = jax.sharding.Sharding Slice = types.Slice SliceNd = types.SliceNd @@ -361,7 +363,7 @@ def _get_array_sharding_or_default(arr: jax.Array) -> Sharding: if hasattr(arr, 'sharding'): return arr.sharding else: - return GSPMDSharding.get_replicated(jax.devices()) + return NamedSharding(Mesh(jax.devices(), 'x'), PartitionSpec()) def _intersect_slicend(a: SliceNd, b: SliceNd) -> Optional[SliceNd]: diff --git a/vmoe/partitioning.py b/vmoe/partitioning.py index 01c15d1..3783135 100644 --- a/vmoe/partitioning.py +++ b/vmoe/partitioning.py @@ -87,7 +87,7 @@ def get_array_sharding_or_default(arr: jax.Array) -> jax.sharding.Sharding: if hasattr(arr, 'sharding'): return arr.sharding else: - return jax.sharding.GSPMDSharding.get_replicated(jax.devices()) + return NamedSharding(Mesh(jax.devices(), 'x'), PartitionSpec()) def process_has_contiguous_device_slice(devices: np.ndarray,