Skip to content

Commit 4211739

Browse files
committed
test
1 parent 5ccd8f4 commit 4211739

File tree

4 files changed

+34
-35
lines changed

4 files changed

+34
-35
lines changed

Dockerfile

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,20 @@ ARG EXTRAS=
107107
# Install a custom jaxlib that includes backport of Pathways shared memory feature.
108108
# PR: https://github.com/openxla/xla/pull/31417
109109
# Needed until Jax is upgraded to 0.8.0 or newer.
110-
ARG INSTALL_PATHWAYS_JAXLIB=false
110+
ARG INSTALL_PATHWAYS_JAXLIB=true
111111

112112
# Ensure we install the TPU version, even if building locally.
113113
# Jax will fallback to CPU when run on a machine without TPU.
114114
RUN uv pip install -qq --prerelease=allow .[core,tpu] && uv cache clean
115115
RUN if [ -n "$EXTRAS" ]; then uv pip install -qq .[$EXTRAS] && uv cache clean; fi
116+
117+
COPY jaxlib-0.6.2.dev20251021-cp310-cp310-manylinux2014_x86_64.whl .
118+
119+
# 2. RUN the pip install command using the new, simple path *inside* the container
116120
RUN if [ "$INSTALL_PATHWAYS_JAXLIB" = "true" ]; then \
117-
uv pip install --prerelease=allow "jaxlib==0.5.3.dev20250918" \
118-
--find-links https://storage.googleapis.com/axlearn-wheels/wheels.html; \
121+
# uv pip install --prerelease=allow "jaxlib==0.6.2.dev20251020" \
122+
# --find-links https://storage.googleapis.com/axlearn-wheels/wheels.html; \
123+
uv pip install jaxlib-0.6.2.dev20251021-cp310-cp310-manylinux2014_x86_64.whl; \
119124
fi
120125
COPY . .
121126

axlearn/cloud/gcp/pathways_utils.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -334,29 +334,29 @@ def _build_pathways_head_container(self) -> dict:
334334
}
335335
)
336336

337-
# pylint: disable=line-too-long
338-
env_list.append(
339-
{
340-
"name": "NUM_REPLICAS",
341-
"valueFrom": {
342-
"fieldRef": {
343-
"fieldPath": "metadata.annotations['jobset.sigs.k8s.io/replicatedjob-replicas']"
344-
}
345-
},
346-
}
347-
)
337+
# # pylint: disable=line-too-long
338+
# env_list.append(
339+
# {
340+
# "name": "NUM_REPLICAS",
341+
# "valueFrom": {
342+
# "fieldRef": {
343+
# "fieldPath": "metadata.annotations['jobset.sigs.k8s.io/replicatedjob-replicas']"
344+
# }
345+
# },
346+
# }
347+
# )
348348
# pylint: enable=line-too-long
349349

350-
env_list.append(
351-
{
352-
"name": "REPLICA_ID",
353-
"valueFrom": {
354-
"fieldRef": {
355-
"fieldPath": "metadata.annotations['jobset.sigs.k8s.io/job-index']"
356-
}
357-
},
358-
}
359-
)
350+
# env_list.append(
351+
# {
352+
# "name": "REPLICA_ID",
353+
# "valueFrom": {
354+
# "fieldRef": {
355+
# "fieldPath": "metadata.annotations['jobset.sigs.k8s.io/job-index']"
356+
# }
357+
# },
358+
# }
359+
# )
360360

361361
head_container["env"] = env_list
362362

73.2 MB
Binary file not shown.

pyproject.toml

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ core = [
4343
"tensorflow-metadata==1.17.2", # Otherwise Seqio will report no core for tfds
4444
"tensorflow-io>=0.37.1", # for tensorflow-2.16. Note that 0.37.0 results in "pure virtual method called".
4545
"tensorflow-metadata>=1.0.0", # Otherwise Seqio will report no core for tfds
46-
"tensorflow_text==2.17.0", # implied by seqio, but also used directly for text processing
46+
"tensorflow_text==2.19.0", # implied by seqio, but also used directly for text processing
4747
"tensorstore>=0.1.63", # used for supporting GDA checkpoints
4848
"toml", # for config management
4949
"typing-extensions==4.12.2",
@@ -71,15 +71,14 @@ dev = [
7171
"pylint==2.17.7",
7272
"pytest", # test runner
7373
"pytest-xdist", # pytest plugin for test parallelism
74-
"pytest-timeout", # pytest plugin for forcing timeout of tests
7574
"pytype==2022.4.22", # type checking
7675
"scikit-learn==1.5.2", # test-only
7776
# Fix AttributeError: module 'scipy.linalg' has no attribute 'tril' and related scipy import errors.
7877
"scipy==1.12.0",
7978
"sentencepiece != 0.1.92",
8079
"tqdm", # test-only
8180
"timm==0.6.12", # DiT Dependency test-only
82-
"torch>=2.1.1", # test-only
81+
"torch>=1.12.1", # test-only
8382
"torchvision==0.16.1", # test-only
8483
"safetensors<=0.5.3", # TODO: Remove once torch dependency is >=2.3.0
8584
"transformers==4.51.3", # test-only
@@ -199,7 +198,6 @@ line-length = 100
199198
target-version = 'py39'
200199

201200
[tool.pytest.ini_options]
202-
timeout = 300
203201
addopts = "-rs -s -p no:warnings --junitxml=test-results/testing.xml"
204202
markers = [
205203
"gs_login: tests needing GS login.",
@@ -219,13 +217,9 @@ junit_family="xunit2"
219217
line_length = 100
220218
profile = "black"
221219

222-
[tool.uv.pip]
220+
[tool.uv]
223221
find-links = [
222+
"https://storage.googleapis.com/axlearn-wheels/wheels.html",
224223
"https://storage.googleapis.com/jax-releases/libtpu_releases.html",
225224
"https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
226-
]
227-
228-
[tool.uv]
229-
override-dependencies = [
230-
"ml-dtypes>=0.5,<0.6",
231-
]
225+
]

0 commit comments

Comments
 (0)