2020from absl import flags
2121from etils import epath
2222from flax .training import train_state
23+ import datetime
2324import jax
2425from MaxText import exceptions
2526from MaxText import max_logging
5051PersistentCheckpointOptions = emergency_checkpoint_manager .PersistentCheckpointOptions
5152EmergencyReplicatorCheckpointManager = emergency_replicator_checkpoint_manager .ReplicatorCheckpointManager
5253
54+ ASYNC_SAVE_TIMEOUT = datetime .timedelta (minutes = 35 )
5355
5456class GrainCheckpointHandler (PyGrainCheckpointHandler , ocp .CheckpointHandler ):
5557 """A CheckpointHandler that allows specifying process_index and process_count."""
@@ -230,6 +232,11 @@ def create_orbax_checkpoint_manager(
230232 preservation_policy = preservation_policy_lib .LatestN (
231233 max_num_checkpoints_to_keep
232234 )
235+ async_options = None
236+ if enable_continuous_checkpointing :
237+ async_options = ocp .AsyncOptions (
238+ timeout_secs = int (ASYNC_SAVE_TIMEOUT .total_seconds ()),
239+ )
233240 manager = CheckpointManager (
234241 p ,
235242 item_names = item_names ,
@@ -239,7 +246,8 @@ def create_orbax_checkpoint_manager(
239246 enable_async_checkpointing = use_async ,
240247 save_decision_policy = save_decision_policy ,
241248 preservation_policy = preservation_policy ,
242- ),
249+ async_options = async_options ,
250+ ),
243251 logger = orbax_logger ,
244252 )
245253
@@ -711,7 +719,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
711719 if config and config .enable_checkpointing :
712720 if (
713721 force
714- or (step % config .checkpoint_period == 0 )
722+ or (step % config .checkpoint_period == 0 and not config . enable_continuous_checkpointing )
715723 or (config .enable_emergency_checkpoint and step % config .local_checkpoint_period == 0 )
716724 ):
717725 blocking_until_ready_start = time .time ()
0 commit comments