From 8be945b5b780e8a088883d103bde6eb8193903a2 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Wed, 28 Jan 2026 12:52:37 -0800
Subject: [PATCH 1/5] Continue if checkpointing fails

---
 ScaFFold/utils/checkpointing.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/ScaFFold/utils/checkpointing.py b/ScaFFold/utils/checkpointing.py
index 0bab949..92c5203 100644
--- a/ScaFFold/utils/checkpointing.py
+++ b/ScaFFold/utils/checkpointing.py
@@ -257,14 +257,22 @@ def save_checkpoint(
     def _write_to_disk(state_dict, last_path, best_path, is_best):
         """Worker function to perform actual disk I/O."""
         # Save 'last'
-        torch.save(state_dict, last_path)
+        try:
+            torch.save(state_dict, last_path)
+        except Exception as e:
+            print("Saving checkpoint failed. Continuing...")
+            print(e)
         # Save 'best' (copy logic)
         if is_best:
             # Copy is often faster than re-serializing
             if last_path.exists():
                 shutil.copyfile(last_path, best_path)
             else:
-                torch.save(state_dict, best_path)
+                try:
+                    torch.save(state_dict, best_path)
+                except Exception as e:
+                    print("Saving checkpoint failed. Continuing...")
+                    print(e)
 
     def _transfer_dict_to_cpu(self, obj):
         """Recursively move tensors to CPU."""

From 3d31a270a9e2aa86d8b53a048d3e00e475b548a2 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Wed, 28 Jan 2026 12:53:02 -0800
Subject: [PATCH 2/5] Fixes for new distconv

---
 ScaFFold/utils/trainer.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index c76646a..5307a1f 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -295,7 +295,11 @@ def train(self):
                 masks_pred_dc = self.model(images_dc)
 
                 # Convert predictions for loss
-                if images.size(0) < ps.num_shards:
+                if isinstance(ps.num_shards, tuple) and len(ps.num_shards) == 1:
+                    n_shards = ps.num_shards[0]
+                else:
+                    n_shards = ps.num_shards
+                if images.size(0) < n_shards:
                     # For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs
                     masks_pred = masks_pred_dc.to_replicate()
                     labels_for_loss = true_masks
@@ -419,11 +423,11 @@ def train(self):
                     true_masks_ddp = (
                         DTensor.from_local(
                             true_masks_dp,
-                            device_mesh=ps.device_mesh["dc"],
+                            device_mesh=ps.device_mesh["dc4"],
                             placements=[Replicate()],
                         )
                         .redistribute(
-                            device_mesh=ps.device_mesh["dc"],
+                            device_mesh=ps.device_mesh["dc4"],
                             placements=[Shard(0)],
                         )
                         .to_local()

From 35db2d546bacb9bb1eec9db7c74115f896200534 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Wed, 28 Jan 2026 12:53:18 -0800
Subject: [PATCH 3/5] enable running 1 epoch

---
 ScaFFold/worker.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py
index 33f8949..07bedf0 100644
--- a/ScaFFold/worker.py
+++ b/ScaFFold/worker.py
@@ -239,7 +239,8 @@ def main(kwargs_dict: dict = {}):
     outfile_path = trainer.outfile_path
     train_data = np.genfromtxt(outfile_path, dtype=float, delimiter=",", names=True)
     total_train_time = train_data["epoch_duration"].sum()
-    total_epochs = train_data["epoch"][-1]
+    epochs = np.atleast_1d(train_data["epoch"])
+    total_epochs = int(epochs[-1])
     log.info(
         f"Benchmark run at scale {config.problem_scale} complete. \n\
 Trained to >= 0.95 validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs."
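Why PATCH 3 matters: when a run finishes in a single epoch, the outfile CSV
holds exactly one data row, and numpy.genfromtxt(..., names=True) then returns
a 0-d structured array, so indexing the "epoch" field with [-1] raises
IndexError. A minimal self-contained sketch of the failure mode and the fix,
using synthetic CSV data in place of the real ScaFFold outfile:

    import io

    import numpy as np

    # A single-epoch run produces exactly one data row in the CSV.
    one_epoch_csv = io.StringIO("epoch,epoch_duration\n1,12.5\n")
    train_data = np.genfromtxt(
        one_epoch_csv, dtype=float, delimiter=",", names=True
    )

    # train_data["epoch"] is 0-d here, so train_data["epoch"][-1] raises
    # IndexError; atleast_1d promotes it to shape (1,) so [-1] always works.
    epochs = np.atleast_1d(train_data["epoch"])
    total_epochs = int(epochs[-1])
    print(total_epochs)  # -> 1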
From 95f2a0b45578d67543cbb990cc7069477164c64e Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 29 Jan 2026 11:24:25 -0800
Subject: [PATCH 4/5] Update trainer.py

---
 ScaFFold/utils/trainer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 5307a1f..0e1eff9 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -308,7 +308,7 @@ def train(self):
                     masks_pred = masks_pred_dc.to_ddp()
                     dt_labels = distribute_tensor(
                         true_masks,
-                        device_mesh=ps.device_mesh["dc"],
+                        device_mesh=ps.device_mesh[f"dc{self.config.shard_dim+2}"],
                         placements=[Shard(0)],
                     )
                     labels_for_loss = dt_labels.to_local()
@@ -423,11 +423,11 @@ def train(self):
                     true_masks_ddp = (
                         DTensor.from_local(
                             true_masks_dp,
-                            device_mesh=ps.device_mesh["dc4"],
+                            device_mesh=ps.device_mesh[f"dc{self.config.shard_dim+2}"],
                             placements=[Replicate()],
                         )
                         .redistribute(
-                            device_mesh=ps.device_mesh["dc4"],
+                            device_mesh=ps.device_mesh[f"dc{self.config.shard_dim+2}"],
                             placements=[Shard(0)],
                         )
                         .to_local()

From b13d9b64c9a9db5de88c6eb114678572e76a5fc6 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 29 Jan 2026 13:41:20 -0800
Subject: [PATCH 5/5] Update trainer.py

---
 ScaFFold/utils/trainer.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 0e1eff9..4196c3a 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -308,7 +308,9 @@ def train(self):
                     masks_pred = masks_pred_dc.to_ddp()
                     dt_labels = distribute_tensor(
                         true_masks,
-                        device_mesh=ps.device_mesh[f"dc{self.config.shard_dim+2}"],
+                        device_mesh=ps.device_mesh[
+                            f"dc{self.config.shard_dim + 2}"
+                        ],
                         placements=[Shard(0)],
                     )
                     labels_for_loss = dt_labels.to_local()
@@ -423,11 +425,15 @@ def train(self):
                     true_masks_ddp = (
                         DTensor.from_local(
                             true_masks_dp,
-                            device_mesh=ps.device_mesh[f"dc{self.config.shard_dim+2}"],
+                            device_mesh=ps.device_mesh[
+                                f"dc{self.config.shard_dim + 2}"
+                            ],
                             placements=[Replicate()],
                         )
                         .redistribute(
-                            device_mesh=ps.device_mesh[f"dc{self.config.shard_dim+2}"],
+                            device_mesh=ps.device_mesh[
+                                f"dc{self.config.shard_dim + 2}"
+                            ],
                             placements=[Shard(0)],
                         )
                         .to_local()
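Taken together, PATCHES 2, 4, and 5 make two behavioral changes: num_shards
reported by the new distconv may arrive as a 1-tuple and must be unwrapped
before the batch-size comparison, and the device-mesh key is derived from the
configured shard dimension instead of being hard-coded. A minimal sketch of
both pieces of logic, where normalize_num_shards and dc_mesh_name are
hypothetical helpers (the patched code inlines this logic directly in
trainer.py):

    # Hypothetical helpers mirroring the logic the patches inline in trainer.py.

    def normalize_num_shards(num_shards):
        # The new distconv may report num_shards as a 1-tuple such as (2,);
        # unwrap it so comparisons like images.size(0) < n_shards still work.
        if isinstance(num_shards, tuple) and len(num_shards) == 1:
            return num_shards[0]
        return num_shards

    def dc_mesh_name(shard_dim: int) -> str:
        # Mesh keys follow the pattern "dc{shard_dim + 2}"; shard_dim == 2
        # reproduces the "dc4" key that PATCH 2 briefly hard-coded.
        return f"dc{shard_dim + 2}"

    assert normalize_num_shards((2,)) == 2
    assert normalize_num_shards(2) == 2
    assert dc_mesh_name(2) == "dc4"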