diff --git a/ScaFFold/utils/checkpointing.py b/ScaFFold/utils/checkpointing.py
index 0bab949..92c5203 100644
--- a/ScaFFold/utils/checkpointing.py
+++ b/ScaFFold/utils/checkpointing.py
@@ -257,14 +257,22 @@ def save_checkpoint(
     def _write_to_disk(state_dict, last_path, best_path, is_best):
         """Worker function to perform actual disk I/O."""
         # Save 'last'
-        torch.save(state_dict, last_path)
+        try:
+            torch.save(state_dict, last_path)
+        except Exception as e:
+            print("Saving checkpoint failed. Continuing...")
+            print(e)
         # Save 'best' (copy logic)
         if is_best:
             # Copy is often faster than re-serializing
             if last_path.exists():
                 shutil.copyfile(last_path, best_path)
             else:
-                torch.save(state_dict, best_path)
+                try:
+                    torch.save(state_dict, best_path)
+                except Exception as e:
+                    print("Saving checkpoint failed. Continuing...")
+                    print(e)
 
     def _transfer_dict_to_cpu(self, obj):
         """Recursively move tensors to CPU."""
diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index c76646a..4196c3a 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -295,7 +295,11 @@ def train(self):
                 masks_pred_dc = self.model(images_dc)
 
                 # Convert predictions for loss
-                if images.size(0) < ps.num_shards:
+                if isinstance(ps.num_shards, tuple) and len(ps.num_shards) == 1:
+                    n_shards = ps.num_shards[0]
+                else:
+                    n_shards = ps.num_shards
+                if images.size(0) < n_shards:
                     # For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs
                     masks_pred = masks_pred_dc.to_replicate()
                     labels_for_loss = true_masks
@@ -304,7 +308,9 @@
                     masks_pred = masks_pred_dc.to_ddp()
                     dt_labels = distribute_tensor(
                         true_masks,
-                        device_mesh=ps.device_mesh["dc"],
+                        device_mesh=ps.device_mesh[
+                            f"dc{self.config.shard_dim + 2}"
+                        ],
                         placements=[Shard(0)],
                     )
                     labels_for_loss = dt_labels.to_local()
@@ -419,11 +425,15 @@
                 true_masks_ddp = (
                     DTensor.from_local(
                         true_masks_dp,
-                        device_mesh=ps.device_mesh["dc"],
+                        device_mesh=ps.device_mesh[
+                            f"dc{self.config.shard_dim + 2}"
+                        ],
                         placements=[Replicate()],
                     )
                     .redistribute(
-                        device_mesh=ps.device_mesh["dc"],
+                        device_mesh=ps.device_mesh[
+                            f"dc{self.config.shard_dim + 2}"
+                        ],
                         placements=[Shard(0)],
                     )
                     .to_local()
diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py
index 33f8949..07bedf0 100644
--- a/ScaFFold/worker.py
+++ b/ScaFFold/worker.py
@@ -239,7 +239,8 @@ def main(kwargs_dict: dict = {}):
     outfile_path = trainer.outfile_path
     train_data = np.genfromtxt(outfile_path, dtype=float, delimiter=",", names=True)
     total_train_time = train_data["epoch_duration"].sum()
-    total_epochs = train_data["epoch"][-1]
+    epochs = np.atleast_1d(train_data["epoch"])
+    total_epochs = int(epochs[-1])
     log.info(
         f"Benchmark run at scale {config.problem_scale} complete. \n\
 Trained to >= 0.95 validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs."
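Context for the worker.py hunk: with names=True, np.genfromtxt returns a 0-d structured array when the results CSV contains only a single data row (a one-epoch run), and indexing a 0-d array with [-1] raises IndexError. The standalone sketch below reproduces that edge case; the CSV contents and values are illustrative, only the column names (epoch, epoch_duration) come from the diff.

import io

import numpy as np

# One header row plus exactly one data row, as a one-epoch run would produce.
csv_text = "epoch,epoch_duration\n1,42.0\n"
train_data = np.genfromtxt(io.StringIO(csv_text), dtype=float, delimiter=",", names=True)

print(train_data.shape)       # () -- a 0-d structured array, not a 1-d table
# train_data["epoch"][-1]     # IndexError: 0-d arrays cannot be indexed

epochs = np.atleast_1d(train_data["epoch"])  # promote 0-d to 1-d so [-1] is safe
total_epochs = int(epochs[-1])               # 1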