From 0e44638b9b6021b328642cc4cc8a9216b3b8f764 Mon Sep 17 00:00:00 2001
From: gobbleturk
Date: Mon, 29 Dec 2025 19:28:05 +0000
Subject: [PATCH] Fix sft + ga

---
 src/MaxText/sft_trainer.py                     |  5 ++-
 .../gradient_accumulation_test.py              | 45 ++++++++++++++++---
 2 files changed, 44 insertions(+), 6 deletions(-)

diff --git a/src/MaxText/sft_trainer.py b/src/MaxText/sft_trainer.py
index bc60b32b1a..272d95d2dc 100644
--- a/src/MaxText/sft_trainer.py
+++ b/src/MaxText/sft_trainer.py
@@ -35,6 +35,7 @@
 from MaxText import profiler
 from MaxText import pyconfig
 from MaxText import train_utils
+from MaxText import sharding
 from MaxText.data_loader import DataLoader
 from MaxText.metric_logger import MetricLogger
 from MaxText.train import (
@@ -71,8 +72,10 @@ def train_loop(config, recorder, state=None):
       state,
   ) = setup_train_loop(config, recorder)
 
+  params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
+
   p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
-      config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator
+      config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator, params_shardings
   )
 
   with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
diff --git a/tests/integration_tests/gradient_accumulation_test.py b/tests/integration_tests/gradient_accumulation_test.py
index 8e6db3043d..0fca7ac008 100644
--- a/tests/integration_tests/gradient_accumulation_test.py
+++ b/tests/integration_tests/gradient_accumulation_test.py
@@ -25,6 +25,7 @@
 import os.path
 
 from MaxText.train import main as train_main
+from MaxText.sft_trainer import main as sft_main
 from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT
 
 
@@ -84,8 +85,14 @@ def test_grad_accumulate_same_loss(self):
     ):
       accum_run_loss = json.loads(accum_run.readlines()[-1])["learning/loss"]
       regular_run_loss = json.loads(regular_run.readlines()[-1])["learning/loss"]
-      print(f"[Gradient Accumulation Test] Loss with gradient accumulation: {accum_run_loss}", flush=True)
-      print(f"[Gradient Accumulation Test] Loss without gradient accumulation: {regular_run_loss}", flush=True)
+      print(
+          f"[Gradient Accumulation Test] Loss with gradient accumulation: {accum_run_loss}",
+          flush=True,
+      )
+      print(
+          f"[Gradient Accumulation Test] Loss without gradient accumulation: {regular_run_loss}",
+          flush=True,
+      )
       # Not identical due to an epsilon addition in loss denominator.
       np.testing.assert_allclose(accum_run_loss, regular_run_loss, rtol=0.01)
 
@@ -96,8 +103,14 @@ def test_grad_accumulate_same_loss(self):
     ):
       accum_run_grad_norm = json.loads(accum_run.readlines()[-1])["learning/raw_grad_norm"]
       regular_run_grad_norm = json.loads(regular_run.readlines()[-1])["learning/raw_grad_norm"]
-      print(f"[Gradient Accumulation Test] Grad norm with gradient accumulation: {accum_run_grad_norm}", flush=True)
-      print(f"[Gradient Accumulation Test] Grad norm without gradient accumulation: {regular_run_grad_norm}", flush=True)
+      print(
+          f"[Gradient Accumulation Test] Grad norm with gradient accumulation: {accum_run_grad_norm}",
+          flush=True,
+      )
+      print(
+          f"[Gradient Accumulation Test] Grad norm without gradient accumulation: {regular_run_grad_norm}",
+          flush=True,
+      )
       # Not identical due to an epsilon addition in loss denominator.
       np.testing.assert_allclose(accum_run_grad_norm, regular_run_grad_norm, rtol=0.01)
 
@@ -109,10 +122,32 @@ def test_grad_accumulate_same_loss(self):
       accum_device_tflops = json.loads(accum_run.readlines()[-1])["perf/per_device_tflops"]
       regular_device_tflops = json.loads(regular_run.readlines()[-1])["perf/per_device_tflops"]
       print(
-          f"[Gradient Accumulation Test] per_device_tflops with gradient accumulation: {accum_device_tflops}", flush=True
+          f"[Gradient Accumulation Test] per_device_tflops with gradient accumulation: {accum_device_tflops}",
+          flush=True,
       )
       print(
           f"[Gradient Accumulation Test] per_device_tflops without gradient accumulation: {regular_device_tflops}",
           flush=True,
       )
       np.testing.assert_equal(accum_device_tflops, regular_device_tflops)
+
+  @pytest.mark.integration_test
+  @pytest.mark.tpu_only
+  def test_sft_grad_accumulate_same_loss(self):
+    sft_main(
+        [
+            None,
+            os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
+            "base_output_directory=gs://runner-maxtext-logs",
+            "dataset_path=gs://maxtext-dataset",
+            "gradient_clipping_threshold=0",  # Ensures we are testing raw scales of gradients (clipping off).
+            "enable_checkpointing=False",
+            "enable_goodput_recording=False",
+            "base_emb_dim=256",
+            "base_num_decoder_layers=4",
+            rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+            "steps=3",
+            "gradient_accumulation_steps=2",
+            "use_sft=True",
+        ]
+    )
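
Note: the comparisons in this test rely on the property that, with gradient clipping disabled, a mean-reduced loss, and equally sized microbatches, averaging per-microbatch gradients reproduces the full-batch gradient (only approximately here, because of the epsilon added to the loss denominator, hence the rtol-based asserts). A minimal, self-contained JAX sketch of that property follows; toy_loss and the random data are illustrative stand-ins, not MaxText code.

  # Toy check of the gradient-accumulation invariant: averaging the gradients
  # of equally sized microbatches matches the full-batch gradient when the
  # loss is a mean over the batch dimension.
  import jax
  import jax.numpy as jnp
  import numpy as np

  def toy_loss(w, x, y):
    # Mean squared error of a linear model, averaged over the batch dimension.
    return jnp.mean((x @ w - y) ** 2)

  w = jax.random.normal(jax.random.PRNGKey(0), (8,))
  x = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
  y = jax.random.normal(jax.random.PRNGKey(2), (16,))

  # Full-batch gradient (gradient_accumulation_steps=1).
  full_grad = jax.grad(toy_loss)(w, x, y)

  # Accumulated gradient: average the gradients of 2 equally sized microbatches
  # (gradient_accumulation_steps=2).
  micro_grads = [
      jax.grad(toy_loss)(w, xs, ys)
      for xs, ys in zip(jnp.split(x, 2), jnp.split(y, 2))
  ]
  accum_grad = sum(micro_grads) / 2

  np.testing.assert_allclose(accum_grad, full_grad, rtol=1e-5)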