Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion nemo_run/core/execution/dgxcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@
import requests
from invoke.context import Context

from nemo_run.config import get_nemorun_home
from nemo_run.config import RUNDIR_NAME, get_nemorun_home
from nemo_run.core.execution.base import Executor, ExecutorMacros
from nemo_run.core.execution.launcher import FaultTolerance, Launcher, Torchrun
from nemo_run.core.execution.utils import fill_template
from nemo_run.core.frontend.console.api import CONSOLE
from nemo_run.core.packaging.base import Packager
from nemo_run.core.packaging.git import GitArchivePackager

Expand Down Expand Up @@ -461,6 +464,24 @@ def cancel(self, job_id: str):
response.text,
)

def _setup_launcher(self):
    """Apply launcher-dependent settings after the base launcher setup.

    When the configured launcher is a ``Torchrun`` or ``FaultTolerance``
    instance, the executor is switched to a single task per node and
    ``torchrun_nproc_per_node`` takes the value of ``nprocs_per_node``.
    For ``FaultTolerance`` the launcher's config, finished-flag and
    job-results paths are filled in as well.
    """
    super()._setup_launcher()

    active = self.launcher
    if not (active and isinstance(active, (FaultTolerance, Torchrun))):
        # Nothing to adjust for other launchers (or no launcher at all).
        return

    self.torchrun_nproc_per_node = self.nprocs_per_node
    self.ntasks_per_node = 1
    CONSOLE.log(
        f"Detected {active.__class__.__name__} launcher, setting ntasks_per_node=1 and torchrun_nproc_per_node={self.torchrun_nproc_per_node}"
    )

    if not isinstance(active, FaultTolerance):
        return

    # The config and results files live in a sub-directory of job_dir that
    # carries the same name as job_dir itself.
    results_root = os.path.join(self.job_dir, Path(self.job_dir).name)
    job = self.job_name
    active.cfg_path = os.path.join(results_root, f"{job}_ft_cfg.yml")
    # The finished flag is placed under the absolute /<RUNDIR_NAME> directory.
    active.finished_flag_file = os.path.join("/", RUNDIR_NAME, f"{job}_finished_flag")
    active.job_results_file = os.path.join(results_root, f"{job}_job_results")

def cleanup(self, handle: str): ...

def assign(
Expand Down Expand Up @@ -556,3 +577,55 @@ def _default_headers(self, token: Optional[str] = None) -> dict:
if token:
headers["Authorization"] = f"Bearer {token}"
return headers


@dataclass(kw_only=True)
class DGXCloudRequest:
    """Inputs needed to render the DGXC entrypoint shell script."""

    # Command tokens; joined with single spaces to form the training command.
    launch_cmd: list[str]
    # Job names; only shown in the ``repr`` header.
    jobs: list[str]
    # Executor whose ``env_vars`` provide the base environment.
    executor: DGXCloudExecutor
    # Passed to the template as ``max_retries``.
    max_retries: int
    # Extra environment entries; these win over the executor's on key clashes.
    extra_env: dict[str, str]
    # Optional launcher; a ``FaultTolerance`` instance enables the FT section.
    launcher: Optional[Launcher] = None

    def materialize(self) -> str:
        """Render the ``dgxc.sh.j2`` template and return the script text.

        Raises:
            AssertionError: if the launcher is ``FaultTolerance`` but any of
                ``cfg_path``, ``finished_flag_file`` or ``job_results_file``
                is unset.
        """
        # Environment: executor defaults overlaid with the extra entries,
        # emitted as ``export`` lines with upper-cased names.
        merged_env = self.executor.env_vars | self.extra_env
        export_lines = [f"export {name.upper()}={val}" for name, val in merged_env.items()]

        # Template variables shared by both the plain and the FT variants.
        template_vars = {
            "max_retries": self.max_retries,
            "env_vars": export_lines,
            "training_command": " ".join(self.launch_cmd),
            "ft_enabled": self.launcher and isinstance(self.launcher, FaultTolerance),
        }

        # Fault tolerance: pass the three launcher paths through to the template.
        if template_vars["ft_enabled"]:
            ft = self.launcher
            assert (
                ft.cfg_path and ft.finished_flag_file and ft.job_results_file
            ), "Fault Tolerance requires cfg_path, finished_flag_file, and job_results_file"

            template_vars.update(
                fault_tol_cfg_path=ft.cfg_path,
                fault_tol_finished_flag_file=ft.finished_flag_file,
                fault_tol_job_results_file=ft.job_results_file,
            )

        return fill_template("dgxc.sh.j2", template_vars)

    def __repr__(self) -> str:
        """Return a commented header followed by the rendered script."""
        pieces = [
            "# DGXC Entrypoint Script Request",
            f"# Executor: {self.executor.__class__.__name__}",
            f"# Jobs: {self.jobs}",
            "# ---------------------------------------------------",
            f"{self.materialize()}",
            "",
        ]
        return "\n".join(pieces)
31 changes: 31 additions & 0 deletions nemo_run/core/execution/templates/dgxc.sh.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{%- import "ft_launcher_dgxc.j2" as fault_tolerance -%}
{#- Entrypoint script template: exports the environment, optionally runs the fault-tolerance setup, runs the training command and exits with its exit code. -#}
#!/bin/bash

set -evx # Echo commands and abort on any setup error; errexit is switched off around the training command below
export PYTHONUNBUFFERED=1
export TORCHX_MAX_RETRIES={{max_retries}}

{%- for env_var in env_vars %}
{{env_var}}
{%- endfor %}

{%- if ft_enabled %}
{{ fault_tolerance.ft_launcher_setup(fault_tol_cfg_path, fault_tol_finished_flag_file, fault_tol_job_results_file) }}
{%- endif %}

echo "Starting training command..."
set +e # Turn off auto-exit so we can capture the code

{{ training_command }}

exitcode=$?
set -e

echo "Main command exited with code $exitcode"

{%- if ft_enabled %}
{{ fault_tolerance.ft_launcher_teardown() }}
{%- else %}

exit $exitcode
{%- endif %}
24 changes: 24 additions & 0 deletions nemo_run/core/execution/templates/ft_launcher_dgxc.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{#
  Fault-tolerance helper macros, imported by dgxc.sh.j2 as `fault_tolerance`.

  ft_launcher_setup(fault_tol_cfg_path, fault_tol_finished_flag_file, fault_tol_job_results_file)
    Exports FAULT_TOL_CFG_PATH and FAULT_TOL_FINISHED_FLAG_FILE, then looks for
    a file with the finished flag's base name in the directory that holds the
    job results file. If it exists the script exits 0 without running anything
    else; otherwise any existing finished flag and job results file are removed.
    The inner expansions are quoted so paths containing spaces or glob
    characters are not word-split.
#}
{% macro ft_launcher_setup(fault_tol_cfg_path, fault_tol_finished_flag_file, fault_tol_job_results_file) -%}
# This script uses experimental fault tolerance launcher
# Fault tolerance related items
export FAULT_TOL_CFG_PATH="{{fault_tol_cfg_path}}"
export FAULT_TOL_FINISHED_FLAG_FILE="{{fault_tol_finished_flag_file}}"

JOB_RESULTS_FILE="{{fault_tol_job_results_file}}"

is_training_finished() {
    test -f "$(dirname "$JOB_RESULTS_FILE")/$(basename "$FAULT_TOL_FINISHED_FLAG_FILE")"
}

if is_training_finished ; then
    echo "Training is finished";
    exit 0;
else
    rm -f "$FAULT_TOL_FINISHED_FLAG_FILE" "$JOB_RESULTS_FILE"
fi

{%- endmacro %}

{#
  ft_launcher_teardown()
    Exits with the value of the shell variable `exitcode`, which the calling
    script must have set before invoking this macro.
#}
{% macro ft_launcher_teardown() -%}
exit $exitcode
{%- endmacro %}
2 changes: 1 addition & 1 deletion nemo_run/core/execution/templates/slurm.sh.j2
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{%- import "ft_launcher.j2" as fault_tolerance -%}
{%- import "ft_launcher_slurm.j2" as fault_tolerance -%}
#!/bin/bash
#
# Generated by NeMo Run
Expand Down
13 changes: 8 additions & 5 deletions nemo_run/run/torchx_backend/components/ft_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,31 +92,34 @@ def ft_launcher(
):
if workload_check_interval:
ft_args += [
"--ft-param-workload_check_interval",
"--ft-workload_check_interval",
str(workload_check_interval),
]

if initial_rank_heartbeat_timeout:
ft_args += [
"--ft-param-initial_rank_heartbeat_timeout",
"--ft-initial_rank_heartbeat_timeout",
str(initial_rank_heartbeat_timeout),
]

if rank_heartbeat_timeout:
ft_args += [
"--ft-param-rank_heartbeat_timeout",
"--ft-rank_heartbeat_timeout",
str(rank_heartbeat_timeout),
]

if rank_termination_signal:
ft_args += ["--ft-param-rank_termination_signal", rank_termination_signal]
ft_args += ["--ft-rank_termination_signal", rank_termination_signal]

if log_level:
ft_args += ["--ft-param-log_level", log_level]
ft_args += ["--ft-log_level", log_level]

if max_restarts:
ft_args += ["--max-restarts", str(max_restarts)]

if dgxc is True:
ft_args += ["--ft-use-infra-group-rank", "False"]

else:
ft_args = ["--ignore-missing-fault-tol-cfg"]

Expand Down
1 change: 1 addition & 0 deletions nemo_run/run/torchx_backend/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
log_level=launcher.log_level,
max_retries=executor.retries,
max_restarts=launcher.max_restarts,
dgxc=isinstance(executor, DGXCloudExecutor),
use_env=use_env,
)
else:
Expand Down
23 changes: 21 additions & 2 deletions nemo_run/run/torchx_backend/schedulers/dgxcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

from nemo_run.config import get_nemorun_home
from nemo_run.core.execution.base import Executor
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudState
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudRequest, DGXCloudState
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.run.torchx_backend.schedulers.api import SchedulerMixin

Expand Down Expand Up @@ -109,6 +109,23 @@ def _submit_dryrun( # type: ignore
role = values.apply(role)

cmd = [role.entrypoint] + role.args

req = DGXCloudRequest(
launch_cmd=cmd,
jobs=[role.name],
executor=executor,
max_retries=role.max_retries,
extra_env=role.env,
launcher=executor.get_launcher(),
)

# Write the rendered DGXC entrypoint script into the experiment directory
path = os.path.join(executor.experiment_dir, "torchrun_job.sh")
script = req.materialize()

with open(path, "w") as f:
f.write(script)

return AppDryRunInfo(
DGXRequest(app=app, executor=executor, cmd=cmd, name=role.name),
# Minimal function to show the config, if any
Expand All @@ -128,7 +145,9 @@ def schedule(self, dryrun_info: AppDryRunInfo[DGXRequest]) -> str:

# The DGXExecutor's launch call typically returns (job_id, handle).
# We'll call it without additional parameters here.
job_id, status = executor.launch(name=req.name, cmd=req.cmd)
cmd = os.path.join(executor.experiment_dir, "torchrun_job.sh")
req.launch_cmd = ["bash", cmd]
job_id, status = executor.launch(name=req.name, cmd=req.launch_cmd)
if not job_id:
raise RuntimeError("Failed scheduling run on DGX: no job_id returned")

Expand Down
Loading