Skip to content

Commit 0486227

Browse files
committed
use another strategy to get jax version'
1 parent bf749ce commit 0486227

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

axlearn/common/array_serialization.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from jax._src import array, typing
3737
from jax._src.layout import Layout
3838
from jax.experimental.array_serialization import serialization
39+
from packaging import version
3940

4041
from axlearn.common.utils import Tensor
4142

@@ -75,6 +76,7 @@ def shard_coordinate(self):
7576

7677
# Tuple (and thus hashable) representation of a slice object (start, end, step).
7778
_SliceTuple = tuple[Optional[int], Optional[int], Optional[int]]
79+
JAX_VERSION = version.parse(jax.__version__)
7880

7981

8082
def _slices_to_tuple(slices: list[slice]) -> tuple[_SliceTuple, ...]:
@@ -306,12 +308,13 @@ async def _async_serialize(
306308
and arr_inp.is_fully_addressable
307309
)
308310
# pylint: disable=protected-access
309-
if jax.__version__.startswith("0.8.") or jax.__version__ == "0.6.2":
311+
if JAX_VERSION >= version.parse("0.6.2"):
310312
spec_has_metadata = serialization.ts_impl._spec_has_metadata
311-
elif jax.__version__ == "0.5.3":
313+
elif JAX_VERSION >= version.parse("0.5.3"):
312314
spec_has_metadata = serialization._spec_has_metadata
313315
else:
314316
raise ValueError(f"Unsupported JAX version for spec_has_metadata: {jax.__version__}")
317+
315318
if not spec_has_metadata(tensorstore_spec):
316319
# pylint: disable-next=protected-access
317320
tensorstore_spec["metadata"] = serialization._get_metadata(arr_inp)
@@ -488,10 +491,14 @@ async def _async_deserialize(
488491
async def cb(index: array.Index, device: jax.Device):
489492
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
490493
restricted_domain = t.domain.intersect(requested_domain)
491-
estimate_read_memory_footprint = {
492-
"0.6.2": lambda: serialization.ts_impl.estimate_read_memory_footprint,
493-
"0.5.3": lambda: serialization.estimate_read_memory_footprint,
494-
}[jax.__version__]()
494+
if JAX_VERSION >= version.parse("0.6.2"):
495+
estimate_read_memory_footprint = serialization.ts_impl.estimate_read_memory_footprint
496+
elif JAX_VERSION >= version.parse("0.5.3"):
497+
estimate_read_memory_footprint = serialization.estimate_read_memory_foot_print
498+
else:
499+
raise ValueError(
500+
f"Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer"
501+
)
495502
requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
496503
# Limit the bytes read for every shard.
497504
await byte_limiter.wait_for_bytes(requested_bytes)
@@ -569,10 +576,13 @@ async def cb(index: array.Index, device: jax.Device):
569576
return result
570577

571578
# pylint: disable=protected-access
572-
create_async_array_from_callback = {
573-
"0.6.2": lambda: serialization.ts_impl._create_async_array_from_callback,
574-
"0.5.3": lambda: serialization.create_async_array_from_callback,
575-
}[jax.__version__]()
579+
if JAX_VERSION >= version.parse("0.6.2"):
580+
create_async_array_from_callback = serialization.ts_impl._create_async_array_from_callback
581+
elif JAX_VERSION >= version.parse("0.5.3"):
582+
create_async_array_from_callback = serialization.create_async_array_from_callback
583+
else:
584+
raise ValueError("Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer.")
585+
576586
return await create_async_array_from_callback(shape, in_sharding, cb)
577587

578588

@@ -654,10 +664,14 @@ def serialize(
654664

655665
commit_futures = [[] for _ in range(len(tensorstore_specs))]
656666

657-
async_serialize = {
658-
"0.6.2": lambda: serialization.ts_impl.async_serialize,
659-
"0.5.3": lambda: serialization.async_serialize,
660-
}[jax.__version__]()
667+
if JAX_VERSION >= version.parse("0.6.2"):
668+
async_serialize = serialization.ts_impl.async_serialize
669+
elif JAX_VERSION >= version.parse("0.5.3"):
670+
async_serialize = serialization.async_serialize
671+
else:
672+
raise ValueError(
673+
f"Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer."
674+
)
661675

662676
# pylint: disable-next=redefined-outer-name
663677
async def _run_serializer():

0 commit comments

Comments
 (0)