Skip to content

Commit f419557

Browse files
committed
Adds pause-resume elasticity if the proxy backend is used.
Added the changes to the jobset for elastic training to enable elasticity. Added changes to launch_trainer so that the pause_resume decorator is used. Set logging.raiseExceptions=True so that DATA_LOSS errors that occur in debug/info/other log calls are raise exceptions immediately.
1 parent ae9a2c5 commit f419557

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

axlearn/cloud/gcp/pathways_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
343343
f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}",
344344
f"--server_port={_PATHWAYS_PROXY_PORT}",
345345
f"--gcs_scratch_location={staging_location}",
346+
f"--num_elastic_slices={cfg.accelerator.num_replicas}",
346347
]
347348
cmd_args.extend(xla_flags_from_options(self._xla_options).split())
348349

@@ -604,14 +605,21 @@ def _build_pathways_worker_job(
604605
annotations.update(
605606
{"alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool"}
606607
)
608+
# Default value for suspend and resume.
609+
# References:
610+
# https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651
611+
backoffLimit = system.vms_per_slice * _PATHWAYS_BACK_OFF_LIMIT
612+
613+
# This backoffLimit is just for verifying elastic pause-resume
614+
backoffLimit *= 1000
607615

608616
spec = dict(
609617
parallelism=system.vms_per_slice,
610618
completions=system.vms_per_slice,
611619
# Default value for suspend and resume.
612620
# References:
613621
# https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651
614-
backoffLimit=system.vms_per_slice * _PATHWAYS_BACK_OFF_LIMIT,
622+
backoffLimit=backoffLimit,
615623
template=self._build_pathways_worker_pod(pathways_worker_replicated_job_index),
616624
)
617625
worker_job = dict(

axlearn/common/launch_trainer.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""Utilities to launch a trainer."""
44

55
import json
6+
import logging as py_logging
67
import os
78
from typing import Any, Optional
89

@@ -15,6 +16,8 @@
1516
from axlearn.common.utils import MeshShape, get_data_dir, infer_mesh_shape
1617
from axlearn.experiments import TrainerConfigFn, get_named_trainer_config
1718

19+
py_logging.raiseException = True
20+
1821
# Trainer-specific flags.
1922
flags.DEFINE_string(
2023
"module",
@@ -158,8 +161,24 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:
158161
f,
159162
)
160163

161-
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
162-
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
163-
output = trainer.run(prng_key)
164-
measurement.record_event(measurement.Event.END_JOB)
164+
def run() -> Any:
165+
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
166+
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
167+
output = trainer.run(prng_key)
168+
measurement.record_event(measurement.Event.END_JOB)
169+
return output
170+
171+
if FLAGS.jax_backend == "proxy":
172+
# pylint: disable-next=import-error,import-outside-toplevel
173+
from pathwaysutils.elastic import manager
174+
elastic_manager = manager.Manager()
175+
max_retries = 5
176+
timeout = 10 * 60 # ten minutes
177+
run = elastic_manager.pause_resume(
178+
max_retries=max_retries,
179+
timeout=timeout,
180+
)(run)
181+
182+
output = run()
183+
165184
return output

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ gcp = [
108108
tpu = [
109109
"axlearn[gcp]",
110110
"jax[tpu]==0.5.3", # must be >=0.4.19 for compat with v5p.
111-
"pathwaysutils==0.1.1", # For JAX+Pathways single-controller accelerator coordinator.
111+
"pathwaysutils @ git+https://github.com/AI-Hypercomputer/pathways-utils.git@test_796970321", # For JAX+Pathways single-controller accelerator coordinator.
112112
]
113113
# Vertex AI tensorboard. TODO(markblee): Merge with `gcp`.
114114
vertexai_tensorboard = [

0 commit comments

Comments
 (0)