From 2cc6b3f90cdc1acc68bd35ffdb5eff12d6761111 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 29 Jan 2026 17:44:13 -0800 Subject: [PATCH 1/3] set checkpoint interval --- ScaFFold/configs/benchmark_default.yml | 1 + ScaFFold/utils/config_utils.py | 1 + ScaFFold/utils/trainer.py | 5 +++-- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 9ba4bc3..f6d6a17 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -11,6 +11,7 @@ batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. num_shards: 2 # DistConv param: number of shards to divide the tensor into shard_dim: 2 # DistConv param: dimension on which to shard +checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 06ee76d..1f3e3a6 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -72,6 +72,7 @@ def __init__(self, config_dict): self.dataset_reuse_enforce_commit_id = config_dict[ "dataset_reuse_enforce_commit_id" ] + self.checkpoint_interval = config_dict["checkpoint_interval"] class RunConfig(Config): diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 4196c3a..c057452 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -581,8 +581,9 @@ def train(self): # begin_code_region("checkpoint") - extras = {"train_mask_values": self.train_set.mask_values} - self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras) + if epoch % self.config.checkpoint_interval == 0: + extras = {"train_mask_values": self.train_set.mask_values} + self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras) end_code_region("checkpoint") From 1b4863c1b01e6f6e74523f96654e1bb4adc1c18b Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Fri, 30 Jan 2026 11:40:31 -0800 Subject: [PATCH 2/3] truncate stats csv when loading from checkpoint if checkpoint is behind latest CSV entries --- ScaFFold/utils/trainer.py | 54 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index c057452..5a9f426 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -239,6 +239,11 @@ def cleanup_or_resume(self): "train_mask_values" ] + # If we loaded a checkpoint (start_epoch > 1), we must ensure the CSV + # matches the state of that checkpoint. + if self.world_rank == 0 and self.start_epoch > 1 and os.path.exists(self.outfile_path): + self._truncate_stats_file(self.start_epoch) + # Set up the output file headers headers = [ "epoch", @@ -254,6 +259,55 @@ def cleanup_or_resume(self): with open(self.outfile_path, "a", newline="") as outfile: outfile.write(",".join(headers) + "\n") + def _truncate_stats_file(self, start_epoch): + """ + Scans the stats file and truncates it at the first occurrence of + an epoch >= start_epoch. This is O(1) memory and safe for large logs. + """ + self.log.info(f"Truncating {self.outfile_path} to remove epochs >= {start_epoch}") + + try: + # Open in read+update mode ('r+') to allow seeking and truncating + with open(self.outfile_path, "r+") as f: + header = f.readline() + if not header: + return + + # Identify the index of the 'epoch' column + headers = header.strip().split(",") + try: + epoch_idx = headers.index("epoch") + except ValueError: + epoch_idx = 0 + + while True: + # Save the current file position (start of the line) + current_pos = f.tell() + line = f.readline() + + # End of file reached + if not line: + break + + parts = line.strip().split(",") + try: + row_epoch = int(float(parts[epoch_idx])) + + # If we find a row that is "from the future" (or the current restarting epoch) + if row_epoch >= start_epoch: + # Move pointer back to the start of this line + f.seek(current_pos) + # Cut the file off right here + f.truncate() + self.log.info(f"Truncated stats file at byte {current_pos} (found epoch {row_epoch})") + break + except (ValueError, IndexError): + # Skip malformed lines, or decide to stop. + # Usually safe to continue scanning. + pass + + except Exception as e: + self.log.warning(f"Failed to truncate stats file: {e}") def train(self): """ Execute model training From 5930f51238454fc84ebfbefb3b71c507c2213a63 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 5 Feb 2026 18:40:36 -0800 Subject: [PATCH 3/3] lint --- ScaFFold/utils/trainer.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 5a9f426..ac92908 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -241,7 +241,11 @@ def cleanup_or_resume(self): # If we loaded a checkpoint (start_epoch > 1), we must ensure the CSV # matches the state of that checkpoint. - if self.world_rank == 0 and self.start_epoch > 1 and os.path.exists(self.outfile_path): + if ( + self.world_rank == 0 + and self.start_epoch > 1 + and os.path.exists(self.outfile_path) + ): self._truncate_stats_file(self.start_epoch) # Set up the output file headers @@ -261,11 +265,13 @@ def cleanup_or_resume(self): def _truncate_stats_file(self, start_epoch): """ - Scans the stats file and truncates it at the first occurrence of + Scans the stats file and truncates it at the first occurrence of an epoch >= start_epoch. This is O(1) memory and safe for large logs. """ - self.log.info(f"Truncating {self.outfile_path} to remove epochs >= {start_epoch}") - + self.log.info( + f"Truncating {self.outfile_path} to remove epochs >= {start_epoch}" + ) + try: # Open in read+update mode ('r+') to allow seeking and truncating with open(self.outfile_path, "r+") as f: @@ -284,30 +290,33 @@ def _truncate_stats_file(self, start_epoch): # Save the current file position (start of the line) current_pos = f.tell() line = f.readline() - + # End of file reached if not line: break - + parts = line.strip().split(",") try: row_epoch = int(float(parts[epoch_idx])) - + # If we find a row that is "from the future" (or the current restarting epoch) if row_epoch >= start_epoch: # Move pointer back to the start of this line f.seek(current_pos) # Cut the file off right here f.truncate() - self.log.info(f"Truncated stats file at byte {current_pos} (found epoch {row_epoch})") + self.log.info( + f"Truncated stats file at byte {current_pos} (found epoch {row_epoch})" + ) break except (ValueError, IndexError): - # Skip malformed lines, or decide to stop. + # Skip malformed lines, or decide to stop. # Usually safe to continue scanning. pass - + except Exception as e: self.log.warning(f"Failed to truncate stats file: {e}") + def train(self): """ Execute model training