1 change: 1 addition & 0 deletions ScaFFold/configs/benchmark_default.yml
@@ -11,6 +11,7 @@ batch_size: 1 # Batch sizes for each vol size.
optimizer: "ADAM" # "ADAM" is the preferred option, otherwise training defaults to RMSProp.
num_shards: 2 # DistConv param: number of shards to divide the tensor into
shard_dim: 2 # DistConv param: dimension on which to shard
checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems.

# Internal/dev use only
variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15.
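The new checkpoint_interval key is simply a cadence for the save path: checkpoints are written only on epochs divisible by it, which keeps I/O down on slow filesystems. A minimal sketch of how such a key is consumed, assuming PyYAML and a bare dict rather than ScaFFold's Config classes:

import yaml

# Stand-in for benchmark_default.yml; only the new key matters here.
cfg = yaml.safe_load("checkpoint_interval: 10\n")

for epoch in range(1, 31):
    # Same modulo gate the trainer change below applies before saving.
    if epoch % cfg["checkpoint_interval"] == 0:
        print(f"epoch {epoch}: would save a checkpoint")  # fires at 10, 20, 30
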
1 change: 1 addition & 0 deletions ScaFFold/utils/config_utils.py
@@ -72,6 +72,7 @@ def __init__(self, config_dict):
self.dataset_reuse_enforce_commit_id = config_dict[
"dataset_reuse_enforce_commit_id"
]
self.checkpoint_interval = config_dict["checkpoint_interval"]


class RunConfig(Config):
68 changes: 66 additions & 2 deletions ScaFFold/utils/trainer.py
@@ -239,6 +239,15 @@ def cleanup_or_resume(self):
"train_mask_values"
]

# If we loaded a checkpoint (start_epoch > 1), we must ensure the CSV
# matches the state of that checkpoint.
if (
self.world_rank == 0
and self.start_epoch > 1
and os.path.exists(self.outfile_path)
):
self._truncate_stats_file(self.start_epoch)

# Set up the output file headers
headers = [
"epoch",
@@ -254,6 +263,60 @@ def cleanup_or_resume(self):
with open(self.outfile_path, "a", newline="") as outfile:
outfile.write(",".join(headers) + "\n")

def _truncate_stats_file(self, start_epoch):
"""
Scans the stats file and truncates it at the first occurrence of
an epoch >= start_epoch. This is O(1) memory and safe for large logs.
"""
self.log.info(
f"Truncating {self.outfile_path} to remove epochs >= {start_epoch}"
)

try:
# Open in read+update mode ('r+') to allow seeking and truncating
with open(self.outfile_path, "r+") as f:
header = f.readline()
if not header:
return

# Identify the index of the 'epoch' column
headers = header.strip().split(",")
try:
epoch_idx = headers.index("epoch")
except ValueError:
epoch_idx = 0

while True:
# Save the current file position (start of the line)
current_pos = f.tell()
line = f.readline()

# End of file reached
if not line:
break

parts = line.strip().split(",")
try:
row_epoch = int(float(parts[epoch_idx]))

# If we find a row that is "from the future" (or the current restarting epoch)
if row_epoch >= start_epoch:
# Move pointer back to the start of this line
f.seek(current_pos)
# Cut the file off right here
f.truncate()
self.log.info(
f"Truncated stats file at byte {current_pos} (found epoch {row_epoch})"
)
break
except (ValueError, IndexError):
# Skip malformed lines, or decide to stop.
# Usually safe to continue scanning.
pass

except Exception as e:
self.log.warning(f"Failed to truncate stats file: {e}")

def train(self):
"""
Execute model training
@@ -581,8 +644,9 @@ def train(self):
#
begin_code_region("checkpoint")

extras = {"train_mask_values": self.train_set.mask_values}
self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras)
if epoch % self.config.checkpoint_interval == 0:
extras = {"train_mask_values": self.train_set.mask_values}
self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras)

end_code_region("checkpoint")

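For reference, here is the seek/truncate pattern from _truncate_stats_file above as a standalone sketch against a throwaway CSV. The file name and columns are illustrative, not taken from ScaFFold; the point is that the file pointer is rewound to the start of the first stale row and the file is cut there, so memory use stays constant no matter how long the log is.

import os

def truncate_stats_csv(path, start_epoch):
    """Cut a stats CSV at the first row whose epoch is >= start_epoch."""
    with open(path, "r+") as f:
        header = f.readline()
        if not header:
            return
        epoch_idx = header.strip().split(",").index("epoch")
        while True:
            pos = f.tell()           # position of the line about to be read
            line = f.readline()
            if not line:
                return               # hit EOF without finding a stale row
            if int(float(line.split(",")[epoch_idx])) >= start_epoch:
                f.seek(pos)          # rewind to the start of the stale row
                f.truncate()         # drop it and everything after it
                return

# Throwaway demo: epochs 1-3 were logged, training resumes from epoch 3.
with open("stats_demo.csv", "w") as f:
    f.write("epoch,train_loss,val_loss\n1,0.90,0.95\n2,0.80,0.85\n3,0.70,0.75\n")
truncate_stats_csv("stats_demo.csv", start_epoch=3)
print(open("stats_demo.csv").read())  # header plus epochs 1 and 2 remain
os.remove("stats_demo.csv")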