Skip to content

Commit 272c430

Browse files
committed
new jaxlib dev version
1 parent 4211739 commit 272c430

File tree

3 files changed

+5
-14
lines changed

3 files changed

+5
-14
lines changed

Dockerfile

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,8 @@ RUN pip install -qq --upgrade pip && \
4646
FROM base AS ci
4747

4848
# TODO(markblee): Remove gcp,vertexai_tensorboard from CI.
49-
<<<<<<< HEAD
5049
RUN uv pip install -qq .[core,audio,orbax,dev,gcp,vertexai_tensorboard] && \
5150
uv cache clean
52-
=======
53-
RUN uv pip install .[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api] && uv cache clean
54-
>>>>>>> 4109ab0 (Move UV_FIND_LINKS to pyproject.toml)
5551
COPY . .
5652

5753
# Defaults to an empty string, i.e. run pytest against all files.
@@ -70,11 +66,7 @@ FROM base AS bastion
7066
# TODO(markblee): Consider copying large directories separately, to cache more aggressively.
7167
# TODO(markblee): Is there a way to skip the "production" deps?
7268
COPY . /root/
73-
<<<<<<< HEAD
7469
RUN uv pip install -qq .[core,gcp,vertexai_tensorboard] && uv cache clean
75-
=======
76-
RUN uv pip install .[core,gcp,vertexai_tensorboard] && uv cache clean
77-
>>>>>>> 4109ab0 (Move UV_FIND_LINKS to pyproject.toml)
7870

7971
################################################################################
8072
# Dataflow container spec. #
@@ -85,11 +77,7 @@ FROM base AS dataflow
8577
# Beam workers default to creating a new virtual environment on startup. Instead, we want them to
8678
# pickup the venv setup above. An alternative is to install into the global environment.
8779
ENV RUN_PYTHON_SDK_IN_DEFAULT_ENVIRONMENT=1
88-
<<<<<<< HEAD
8980
RUN uv pip install -qq .[core,gcp,dataflow] && uv cache clean
90-
=======
91-
RUN uv pip install .[core,gcp,dataflow] && uv cache clean
92-
>>>>>>> 4109ab0 (Move UV_FIND_LINKS to pyproject.toml)
9381
COPY . .
9482

9583
# Dataflow workers can't start properly if the entrypoint is not set

axlearn/common/array_serialization.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ async def _async_serialize(
307307
)
308308
# pylint: disable=protected-access
309309
spec_has_metadata = {
310+
"0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl._spec_has_metadata,
310311
"0.6.2": lambda: serialization.ts_impl._spec_has_metadata,
311312
"0.5.3": lambda: serialization._spec_has_metadata,
312313
}[jax.__version__]()
@@ -487,6 +488,7 @@ async def cb(index: array.Index, device: jax.Device):
487488
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
488489
restricted_domain = t.domain.intersect(requested_domain)
489490
estimate_read_memory_footprint = {
491+
"0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl.estimate_read_memory_footprint,
490492
"0.6.2": lambda: serialization.ts_impl.estimate_read_memory_footprint,
491493
"0.5.3": lambda: serialization.estimate_read_memory_footprint,
492494
}[jax.__version__]()
@@ -568,6 +570,7 @@ async def cb(index: array.Index, device: jax.Device):
568570

569571
# pylint: disable=protected-access
570572
create_async_array_from_callback = {
573+
"0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl._create_async_array_from_callback,
571574
"0.6.2": lambda: serialization.ts_impl._create_async_array_from_callback,
572575
"0.5.3": lambda: serialization.create_async_array_from_callback,
573576
}[jax.__version__]()
@@ -653,6 +656,7 @@ def serialize(
653656
commit_futures = [[] for _ in range(len(tensorstore_specs))]
654657

655658
async_serialize = {
659+
"0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl.async_serialize,
656660
"0.6.2": lambda: serialization.ts_impl.async_serialize,
657661
"0.5.3": lambda: serialization.async_serialize,
658662
}[jax.__version__]()

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,7 @@ gpu = [
144144
"nvidia-ml-py>=12.560.30",
145145
# pin nccl version, otherwise jax[cuda12] will pull latest version
146146
"nvidia-nccl-cu12==2.27.5",
147-
"nvidia-cudnn-cu12>=9.8.0.87" # Pin CuDNN to at least 9.8 for Jax >= 0.6.2
148-
"nvidia-cudnn-cu12>=9.8.0.87" # Pin CuDNN to at least 9.8 for Jax >= 0.6.2
147+
"nvidia-cudnn-cu12>=9.8.0.87", # Pin CuDNN to at least 9.8 for Jax >= 0.6.2
149148
]
150149
# Open API inference.
151150
open_api = [

0 commit comments

Comments
 (0)