From f17cab6bafc206c91948b7b86223ab09da86fda2 Mon Sep 17 00:00:00 2001 From: Shutong Li Date: Wed, 17 Dec 2025 11:30:11 -0800 Subject: [PATCH] Fix two issues that block the training loop with continuous checkpointing enabled. PiperOrigin-RevId: 845850177 --- src/MaxText/checkpointing.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/MaxText/checkpointing.py b/src/MaxText/checkpointing.py index 27fb674ff..c2d149659 100644 --- a/src/MaxText/checkpointing.py +++ b/src/MaxText/checkpointing.py @@ -20,6 +20,7 @@ from absl import flags from etils import epath from flax.training import train_state +import datetime import jax from MaxText import exceptions from MaxText import max_logging @@ -50,6 +51,7 @@ PersistentCheckpointOptions = emergency_checkpoint_manager.PersistentCheckpointOptions EmergencyReplicatorCheckpointManager = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager +ASYNC_SAVE_TIMEOUT = datetime.timedelta(minutes=60) class GrainCheckpointHandler(PyGrainCheckpointHandler, ocp.CheckpointHandler): """A CheckpointHandler that allows specifying process_index and process_count.""" @@ -230,6 +232,11 @@ def create_orbax_checkpoint_manager( preservation_policy = preservation_policy_lib.LatestN( max_num_checkpoints_to_keep ) + async_options = None + if enable_continuous_checkpointing: + async_options = ocp.AsyncOptions( + timeout_secs=int(ASYNC_SAVE_TIMEOUT.total_seconds()), + ) manager = CheckpointManager( p, item_names=item_names, @@ -239,7 +246,8 @@ def create_orbax_checkpoint_manager( enable_async_checkpointing=use_async, save_decision_policy=save_decision_policy, preservation_policy=preservation_policy, - ), + async_options=async_options, + ), logger=orbax_logger, ) @@ -711,7 +719,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator= if config and config.enable_checkpointing: if ( force - or (step % config.checkpoint_period == 0) + or (step % config.checkpoint_period == 0 and not 
config.enable_continuous_checkpointing) or (config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0) ): blocking_until_ready_start = time.time()