|
36 | 36 | from jax._src import array, typing |
37 | 37 | from jax._src.layout import Layout |
38 | 38 | from jax.experimental.array_serialization import serialization |
| 39 | +from packaging import version |
39 | 40 |
|
40 | 41 | from axlearn.common.utils import Tensor |
41 | 42 |
|
@@ -75,6 +76,7 @@ def shard_coordinate(self): |
75 | 76 |
|
76 | 77 | # Tuple (and thus hashable) representation of a slice object (start, end, step). |
77 | 78 | _SliceTuple = tuple[Optional[int], Optional[int], Optional[int]] |
| 79 | +JAX_VERSION = version.parse(jax.__version__) |
78 | 80 |
|
79 | 81 |
|
80 | 82 | def _slices_to_tuple(slices: list[slice]) -> tuple[_SliceTuple, ...]: |
@@ -306,12 +308,13 @@ async def _async_serialize( |
306 | 308 | and arr_inp.is_fully_addressable |
307 | 309 | ) |
308 | 310 | # pylint: disable=protected-access |
309 | | - if jax.__version__.startswith("0.8.") or jax.__version__ == "0.6.2": |
| 311 | + if JAX_VERSION >= version.parse("0.6.2"): |
310 | 312 | spec_has_metadata = serialization.ts_impl._spec_has_metadata |
311 | | - elif jax.__version__ == "0.5.3": |
| 313 | + elif JAX_VERSION >= version.parse("0.5.3"): |
312 | 314 | spec_has_metadata = serialization._spec_has_metadata |
313 | 315 | else: |
314 | 316 | raise ValueError(f"Unsupported JAX version for spec_has_metadata: {jax.__version__}") |
| 317 | + |
315 | 318 | if not spec_has_metadata(tensorstore_spec): |
316 | 319 | # pylint: disable-next=protected-access |
317 | 320 | tensorstore_spec["metadata"] = serialization._get_metadata(arr_inp) |
@@ -488,10 +491,14 @@ async def _async_deserialize( |
488 | 491 | async def cb(index: array.Index, device: jax.Device): |
489 | 492 | requested_domain = ts.IndexTransform(input_shape=shape)[index].domain |
490 | 493 | restricted_domain = t.domain.intersect(requested_domain) |
491 | | - estimate_read_memory_footprint = { |
492 | | - "0.6.2": lambda: serialization.ts_impl.estimate_read_memory_footprint, |
493 | | - "0.5.3": lambda: serialization.estimate_read_memory_footprint, |
494 | | - }[jax.__version__]() |
| 494 | + if JAX_VERSION >= version.parse("0.6.2"): |
| 495 | + estimate_read_memory_footprint = serialization.ts_impl.estimate_read_memory_footprint |
| 496 | + elif JAX_VERSION >= version.parse("0.5.3"): |
| 497 | + estimate_read_memory_footprint = serialization.estimate_read_memory_foot_print |
| 498 | + else: |
| 499 | + raise ValueError( |
| 500 | + f"Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer" |
| 501 | + ) |
495 | 502 | requested_bytes = estimate_read_memory_footprint(t, restricted_domain) |
496 | 503 | # Limit the bytes read for every shard. |
497 | 504 | await byte_limiter.wait_for_bytes(requested_bytes) |
@@ -569,10 +576,13 @@ async def cb(index: array.Index, device: jax.Device): |
569 | 576 | return result |
570 | 577 |
|
571 | 578 | # pylint: disable=protected-access |
572 | | - create_async_array_from_callback = { |
573 | | - "0.6.2": lambda: serialization.ts_impl._create_async_array_from_callback, |
574 | | - "0.5.3": lambda: serialization.create_async_array_from_callback, |
575 | | - }[jax.__version__]() |
| 579 | + if JAX_VERSION >= version.parse("0.6.2"): |
| 580 | + create_async_array_from_callback = serialization.ts_impl._create_async_array_from_callback |
| 581 | + elif JAX_VERSION >= version.parse("0.5.3"): |
| 582 | + create_async_array_from_callback = serialization.create_async_array_from_callback |
| 583 | + else: |
| 584 | + raise ValueError("Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer.") |
| 585 | + |
576 | 586 | return await create_async_array_from_callback(shape, in_sharding, cb) |
577 | 587 |
|
578 | 588 |
|
@@ -654,10 +664,14 @@ def serialize( |
654 | 664 |
|
655 | 665 | commit_futures = [[] for _ in range(len(tensorstore_specs))] |
656 | 666 |
|
657 | | - async_serialize = { |
658 | | - "0.6.2": lambda: serialization.ts_impl.async_serialize, |
659 | | - "0.5.3": lambda: serialization.async_serialize, |
660 | | - }[jax.__version__]() |
| 667 | + if JAX_VERSION >= version.parse("0.6.2"): |
| 668 | + async_serialize = serialization.ts_impl.async_serialize |
| 669 | + elif JAX_VERSION >= version.parse("0.5.3"): |
| 670 | + async_serialize = serialization.async_serialize |
| 671 | + else: |
| 672 | + raise ValueError( |
| 673 | + f"Unsupported JAX version: {JAX_VERSION}. Version must be 0.5.3 or newer." |
| 674 | + ) |
661 | 675 |
|
662 | 676 | # pylint: disable-next=redefined-outer-name |
663 | 677 | async def _run_serializer(): |
|
0 commit comments