Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ grain_worker_count: 1
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
log_period: 100 # Flushes Tensorboard

# Training steps per loop
steps_per_loop: 100

# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
# Learning rate schedule has either two or three parts:
# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
Expand Down
44 changes: 30 additions & 14 deletions MaxText/ray_trainer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import logging
import ray

from absl import app
from ray_tpu import RayTpuManager
from ray.job_submission import JobSubmissionClient
from trainer import MaxTextTrainer

import logging
import os
import argparse
import pyconfig
from typing import Sequence, Optional
from absl import app

from typing import Sequence


#### Configurations
Expand Down Expand Up @@ -46,12 +42,32 @@ def get_job_submission_id() -> str:
return [job.submission_id for job in jobs if job.job_id == current_job_id][0]


def get_steps_values(args):
"""
Extracts the values of 'steps' and 'steps_per_loop' from args.

Args:
args: A list of key=value as strings.

Returns:
A tuple containing the values of 'steps' and 'steps_per_loop' as integers.
Returns (None, None) if not found.
"""
steps = None
steps_per_loop = None
for item in args:
if item.startswith('steps='):
steps = int(item.split('=')[1])
elif item.startswith('steps_per_loop='):
steps_per_loop = int(item.split('=')[1])
return steps, steps_per_loop


def main(argv: Sequence[str]):
ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers))
run_name = get_job_submission_id()
logging.info("Got args: %s", argv)
logging.info("This run name: %s", run_name)

tpu_resources = RayTpuManager.get_available_resources()
num_detected_tpu_types = len(tpu_resources.keys())
if num_detected_tpu_types == 0:
Expand Down Expand Up @@ -84,16 +100,16 @@ def main(argv: Sequence[str]):
raise e

logging.info("Initialization complete. Starting MaxText training...")
total_steps = 50 #int(args.total_steps)
steps_per_loop = 100 #int(args.steps_per_loop)
steps = 0
steps, steps_per_loop = get_steps_values(argv)
logging.info(f"KubeRay training running for total steps: {steps}, steps per loop: {steps_per_loop}")
steps_counter = 0

while steps < total_steps:
while steps_counter < steps:
logging.info("Training from step %d to %d.", steps, steps_per_loop)

try:
r = ray.get([actor.train.remote(num_steps=steps_per_loop) for actor in actors])
steps = r[0]
steps_counter = r[0]
except Exception as e:
logging.error("Caught error during training: %s", e)
logging.error("Shutting down...")
Expand Down