5 changes: 4 additions & 1 deletion src/MaxText/sft_trainer.py
@@ -35,6 +35,7 @@
from MaxText import profiler
from MaxText import pyconfig
from MaxText import train_utils
from MaxText import sharding
from MaxText.data_loader import DataLoader
from MaxText.metric_logger import MetricLogger
from MaxText.train import (
@@ -71,8 +72,10 @@ def train_loop(config, recorder, state=None):
state,
) = setup_train_loop(config, recorder)

params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)

p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator
config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator, params_shardings
)

with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
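For context, here is a minimal, hypothetical sketch of the pattern this change threads through train_loop: splitting a parameter-only sharding pytree out of the full train-state shardings so it can be passed to the jitted train/eval step separately. The helper name, pytree layout, and flag below are illustrative assumptions, not the actual implementation in MaxText.sharding.

def split_params_shardings(offload_optimizer, state_mesh_shardings):
  # Hypothetical stand-in for a helper like
  # sharding.maybe_update_params_sharding_with_opt: when the optimizer state
  # is placed differently from the params, return the params sub-tree of the
  # sharding pytree so it can be handed to the jitted step on its own;
  # otherwise return None so the existing state shardings are reused as-is.
  if not offload_optimizer:
    return None, state_mesh_shardings
  return state_mesh_shardings["params"], state_mesh_shardings

The real helper presumably derives this from the config rather than a bare flag; the point is only that the jitted step now receives an explicit params sharding spec alongside the full state shardings.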
45 changes: 40 additions & 5 deletions tests/integration_tests/gradient_accumulation_test.py
@@ -25,6 +25,7 @@
import os.path

from MaxText.train import main as train_main
from MaxText.sft_trainer import main as sft_main
from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT


@@ -84,8 +85,14 @@ def test_grad_accumulate_same_loss(self):
):
accum_run_loss = json.loads(accum_run.readlines()[-1])["learning/loss"]
regular_run_loss = json.loads(regular_run.readlines()[-1])["learning/loss"]
print(f"[Gradient Accumulation Test] Loss with gradient accumulation: {accum_run_loss}", flush=True)
print(f"[Gradient Accumulation Test] Loss without gradient accumulation: {regular_run_loss}", flush=True)
print(
f"[Gradient Accumulation Test] Loss with gradient accumulation: {accum_run_loss}",
flush=True,
)
print(
f"[Gradient Accumulation Test] Loss without gradient accumulation: {regular_run_loss}",
flush=True,
)
# Not identical due to an epsilon addition in loss denominator.
np.testing.assert_allclose(accum_run_loss, regular_run_loss, rtol=0.01)

@@ -96,8 +103,14 @@ def test_grad_accumulate_same_loss(self):
):
accum_run_grad_norm = json.loads(accum_run.readlines()[-1])["learning/raw_grad_norm"]
regular_run_grad_norm = json.loads(regular_run.readlines()[-1])["learning/raw_grad_norm"]
print(f"[Gradient Accumulation Test] Grad norm with gradient accumulation: {accum_run_grad_norm}", flush=True)
print(f"[Gradient Accumulation Test] Grad norm without gradient accumulation: {regular_run_grad_norm}", flush=True)
print(
f"[Gradient Accumulation Test] Grad norm with gradient accumulation: {accum_run_grad_norm}",
flush=True,
)
print(
f"[Gradient Accumulation Test] Grad norm without gradient accumulation: {regular_run_grad_norm}",
flush=True,
)
# Not identical due to an epsilon addition in loss denominator.
np.testing.assert_allclose(accum_run_grad_norm, regular_run_grad_norm, rtol=0.01)

@@ -109,10 +122,32 @@ def test_grad_accumulate_same_loss(self):
accum_device_tflops = json.loads(accum_run.readlines()[-1])["perf/per_device_tflops"]
regular_device_tflops = json.loads(regular_run.readlines()[-1])["perf/per_device_tflops"]
print(
f"[Gradient Accumulation Test] per_device_tflops with gradient accumulation: {accum_device_tflops}", flush=True
f"[Gradient Accumulation Test] per_device_tflops with gradient accumulation: {accum_device_tflops}",
flush=True,
)
print(
f"[Gradient Accumulation Test] per_device_tflops without gradient accumulation: {regular_device_tflops}",
flush=True,
)
np.testing.assert_equal(accum_device_tflops, regular_device_tflops)

@pytest.mark.integration_test
@pytest.mark.tpu_only
def test_sft_grad_accumulate_same_loss(self):
sft_main(
[
None,
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
"base_output_directory=gs://runner-maxtext-logs",
"dataset_path=gs://maxtext-dataset",
"gradient_clipping_threshold=0", # Ensures we are testing raw scales of gradients (clipping off).
"enable_checkpointing=False",
"enable_goodput_recording=False",
"base_emb_dim=256",
"base_num_decoder_layers=4",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
"steps=3",
"gradient_accumulation_steps=2",
"use_sft=True",
]
)
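For reference, one way the new SFT test might be run on its own, assuming a TPU host with access to the referenced GCS buckets and the repository's pytest markers registered; the selection flags are standard pytest options rather than project-specific tooling.

import pytest

# Select only the new test by name; -m further restricts the run to tests
# carrying the integration_test marker applied above.
pytest.main(
    [
        "tests/integration_tests/gradient_accumulation_test.py",
        "-k", "test_sft_grad_accumulate_same_loss",
        "-m", "integration_test",
    ]
)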