Skip to content

Commit f6d6227

Browse files
lyglst authored and Google-ML-Automation committed
Fix two issues that block the training loop when continuous checkpointing is enabled.
PiperOrigin-RevId: 845850177
1 parent 4b2f023 commit f6d6227

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/MaxText/checkpointing.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from absl import flags
2121
from etils import epath
2222
from flax.training import train_state
23+
import datetime
2324
import jax
2425
from MaxText import exceptions
2526
from MaxText import max_logging
@@ -50,6 +51,7 @@
5051
PersistentCheckpointOptions = emergency_checkpoint_manager.PersistentCheckpointOptions
5152
EmergencyReplicatorCheckpointManager = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager
5253

54+
ASYNC_SAVE_TIMEOUT = datetime.timedelta(minutes=35)
5355

5456
class GrainCheckpointHandler(PyGrainCheckpointHandler, ocp.CheckpointHandler):
5557
"""A CheckpointHandler that allows specifying process_index and process_count."""
@@ -230,6 +232,11 @@ def create_orbax_checkpoint_manager(
230232
preservation_policy = preservation_policy_lib.LatestN(
231233
max_num_checkpoints_to_keep
232234
)
235+
async_options = None
236+
if enable_continuous_checkpointing:
237+
async_options = ocp.AsyncOptions(
238+
timeout_secs=int(ASYNC_SAVE_TIMEOUT.total_seconds()),
239+
)
233240
manager = CheckpointManager(
234241
p,
235242
item_names=item_names,
@@ -239,7 +246,8 @@ def create_orbax_checkpoint_manager(
239246
enable_async_checkpointing=use_async,
240247
save_decision_policy=save_decision_policy,
241248
preservation_policy=preservation_policy,
242-
),
249+
async_options=async_options,
250+
),
243251
logger=orbax_logger,
244252
)
245253

@@ -711,7 +719,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
711719
if config and config.enable_checkpointing:
712720
if (
713721
force
714-
or (step % config.checkpoint_period == 0)
722+
or (step % config.checkpoint_period == 0 and not config.enable_continuous_checkpointing)
715723
or (config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0)
716724
):
717725
blocking_until_ready_start = time.time()

0 commit comments

Comments
 (0)