From 02487d5d0af8a12b4ef6dde5ad3f091f5ad6c5eb Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Wed, 21 Jan 2026 18:06:48 -0800
Subject: [PATCH 1/7] init

---
 ScaFFold/configs/benchmark_default.yml |  7 +++---
 ScaFFold/utils/config_utils.py         |  1 +
 ScaFFold/utils/trainer.py              | 34 ++++++++++++++++++++------
 ScaFFold/worker.py                     |  3 +++
 4 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml
index 9ba4bc3..bf8545f 100644
--- a/ScaFFold/configs/benchmark_default.yml
+++ b/ScaFFold/configs/benchmark_default.yml
@@ -9,14 +9,14 @@ unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dim
 seed: 42 # Random seed.
 batch_size: 1 # Batch sizes for each vol size.
 optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp.
-num_shards: 2 # DistConv param: number of shards to divide the tensor into
+num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
 shard_dim: 2 # DistConv param: dimension on which to shard
 
 # Internal/dev use only
 variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15.
 n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
 val_split: 25 # In percent.
-epochs: 2000 # Number of training epochs.
+epochs: -1 # Maximum number of training epochs. -1 means no limit (train until target_dice is reached).
 learning_rate: .0001 # Learning rate for training.
 disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
 more_determinism: 0 # If 1, improve model training determinism.
@@ -29,4 +29,5 @@ checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpo
 loss_freq: 1 # Number of epochs between logging the overall loss.
 normalize: 1 # Cateogry search normalization parameter
 warmup_epochs: 1 # How many warmup epochs before training
-dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
\ No newline at end of file
+dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
+target_dice: 0.95 # Training dice score at which training stops.
\ No newline at end of file
diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py
index 06ee76d..4f7d458 100644
--- a/ScaFFold/utils/config_utils.py
+++ b/ScaFFold/utils/config_utils.py
@@ -72,6 +72,7 @@ def __init__(self, config_dict):
         self.dataset_reuse_enforce_commit_id = config_dict[
             "dataset_reuse_enforce_commit_id"
         ]
+        self.target_dice = config_dict["target_dice"]
 
 
 class RunConfig(Config):
diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 4196c3a..4af34fb 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -40,7 +40,7 @@
 
 # Local
 from ScaFFold.utils.evaluate import evaluate
-from ScaFFold.utils.perf_measure import begin_code_region, end_code_region
+from ScaFFold.utils.perf_measure import adiak_value, begin_code_region, end_code_region
 from ScaFFold.utils.utils import gather_and_print_mem
 
 
@@ -341,9 +341,18 @@ def train(self):
         end_code_region("warmup")
         self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s")
 
+
+        epoch = 1
+        dice_score_train = 0
         with open(self.outfile_path, "a", newline="") as outfile:
             start = time.time()
-            for epoch in range(self.start_epoch, self.config.epochs + 1):
+            while dice_score_train < self.config.target_dice:
+                if self.config.epochs != -1 and epoch > self.config.epochs:
+                    print(
+                        f"Maximum epochs reached '{self.config.epochs}'. Concluding training early (may not have converged)."
+                    )
+                    break
+
                 # DistConv ParallelStrategy
                 ps = getattr(self.config, "_parallel_strategy", None)
                 if ps is None:
@@ -364,10 +373,15 @@ def train(self):
                     self.val_loader.sampler.set_epoch(epoch)
 
                 self.model.train()
+                estr = (
+                    f"{epoch}"
+                    if self.config.epochs == -1
+                    else f"{epoch}/{self.config.epochs}"
+                )
                 with tqdm(
                     total=self.n_train // self.world_size,
                     desc=f"({os.path.basename(self.config.run_dir)}) \
-                    Epoch {epoch}/{self.config.epochs}",
+                    Epoch {estr}",
                     unit="img",
                     disable=True if self.world_rank != 0 else False,
                 ) as pbar:
@@ -586,8 +600,12 @@ def train(self):
 
                 end_code_region("checkpoint")
 
-                if val_score >= 0.95:
-                    self.log.info(
-                        f"val_score of {val_score} is > threshold of 0.95. Benchmark run complete. Wrapping up..."
-                    )
-                    return 0
+                dice_score_train = dice_sum
+                epoch += 1
+
+        adiak_value("final_epochs", epoch)
+        # if val_score >= 0.95:
+        #     self.log.info(
+        #         f"val_score of {val_score} is > threshold of 0.95. Benchmark run complete. Wrapping up..."
+        #     )
+        #     return 0
diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py
index 431c2eb..e591d18 100644
--- a/ScaFFold/worker.py
+++ b/ScaFFold/worker.py
@@ -246,6 +246,9 @@ def main(kwargs_dict: dict = {}):
         Trained to >= 0.95 validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs."
     )
 
+    # solve hang?
+    dist.barrier()
+
     #
     # Generate plots
     #

From 40d7760a243b9a4aa240bdf5ba7225f8aa0eb45d Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Wed, 21 Jan 2026 19:19:23 -0800
Subject: [PATCH 2/7] debug

---
 ScaFFold/utils/perf_measure.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/ScaFFold/utils/perf_measure.py b/ScaFFold/utils/perf_measure.py
index e0ac1a2..8fdec7f 100644
--- a/ScaFFold/utils/perf_measure.py
+++ b/ScaFFold/utils/perf_measure.py
@@ -28,6 +28,7 @@
 
         _CALI_PERF_ENABLED = True
     except Exception:
+        raise
         print("User requested Caliper annotations, but could not import Caliper")
 elif (
     TORCH_PERF_ENV_VAR in os.environ

From 2f18e4d58a42d86bb4f78d094058fba73f9fa7a6 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Wed, 21 Jan 2026 23:24:25 -0800
Subject: [PATCH 3/7] testing

---
 ScaFFold/worker.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py
index e591d18..ff4ac8a 100644
--- a/ScaFFold/worker.py
+++ b/ScaFFold/worker.py
@@ -247,7 +247,11 @@ def main(kwargs_dict: dict = {}):
     )
 
     # solve hang?
-    dist.barrier()
+    if os.getenv("SKIP_DIST_BARRIERS") != "1":
+        torch.cuda.synchronize()
+        print(f"Done cuda sync rank {rank}")
+        dist.barrier()
+        print(f"Done barrier rank {rank}")
 
     #
     # Generate plots
@@ -258,7 +262,9 @@ def main(kwargs_dict: dict = {}):
         standard_viz.main(config)
         end_code_region("generate_figures")
 
-    dist.barrier()
-    dist.destroy_process_group()
+    if os.getenv("SKIP_DIST_BARRIERS") != "1":
+        dist.barrier()
+        print(f"Done barrier rank {rank}")
+        dist.destroy_process_group()
 
     return 0

From 370c20e9c7a6b1aa9cbf212c3a1094032984b346 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Fri, 23 Jan 2026 16:39:20 -0800
Subject: [PATCH 4/7] Enable configuring n_categories

---
 ScaFFold/cli.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py
index 82532fe..a360c38 100644
--- a/ScaFFold/cli.py
+++ b/ScaFFold/cli.py
@@ -78,6 +78,11 @@ def main():
         type=int,
         help="Determines dataset resolution and number of UNet layers.",
     )
+    generate_fractals_parser.add_argument(
+        "--n-categories",
+        type=int,
+        help="Number of fractal categories present in the dataset.",
+    )
     generate_fractals_parser.add_argument(
         "--fract-base-dir",
         type=str,
@@ -114,7 +119,6 @@ def main():
     benchmark_parser.add_argument(
         "--n-categories",
         type=int,
-        nargs="+",
         help="Number of fractal categories present in the dataset.",
     )
     benchmark_parser.add_argument(

From 2cc6b3f90cdc1acc68bd35ffdb5eff12d6761111 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 29 Jan 2026 17:44:13 -0800
Subject: [PATCH 5/7] set checkpoint interval

---
 ScaFFold/configs/benchmark_default.yml | 1 +
 ScaFFold/utils/config_utils.py         | 1 +
 ScaFFold/utils/trainer.py              | 5 +++--
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml
index 9ba4bc3..f6d6a17 100644
--- a/ScaFFold/configs/benchmark_default.yml
+++ b/ScaFFold/configs/benchmark_default.yml
@@ -11,6 +11,7 @@ batch_size: 1 # Batch sizes for each vol size.
 optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp.
 num_shards: 2 # DistConv param: number of shards to divide the tensor into
 shard_dim: 2 # DistConv param: dimension on which to shard
+checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems.
 
 # Internal/dev use only
 variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15.
diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py
index 06ee76d..1f3e3a6 100644
--- a/ScaFFold/utils/config_utils.py
+++ b/ScaFFold/utils/config_utils.py
@@ -72,6 +72,7 @@ def __init__(self, config_dict):
         self.dataset_reuse_enforce_commit_id = config_dict[
             "dataset_reuse_enforce_commit_id"
         ]
+        self.checkpoint_interval = config_dict["checkpoint_interval"]
 
 
 class RunConfig(Config):
diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 4196c3a..c057452 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -581,8 +581,9 @@ def train(self):
                 #
                 begin_code_region("checkpoint")
 
-                extras = {"train_mask_values": self.train_set.mask_values}
-                self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras)
+                if epoch % self.config.checkpoint_interval == 0:
+                    extras = {"train_mask_values": self.train_set.mask_values}
+                    self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras)
 
                 end_code_region("checkpoint")
 

From 75a1b87dc11d9bb42f9a2d7ecafc12cb1bd2ee56 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 5 Feb 2026 18:35:14 -0800
Subject: [PATCH 6/7] cleanup

---
 ScaFFold/utils/perf_measure.py |  4 ++--
 ScaFFold/utils/trainer.py      |  6 +-----
 ScaFFold/worker.py             | 13 ++-----------
 3 files changed, 5 insertions(+), 18 deletions(-)

diff --git a/ScaFFold/utils/perf_measure.py b/ScaFFold/utils/perf_measure.py
index 8fdec7f..5af8d5b 100644
--- a/ScaFFold/utils/perf_measure.py
+++ b/ScaFFold/utils/perf_measure.py
@@ -27,9 +27,9 @@
         from pycaliper.instrumentation import begin_region, end_region
 
         _CALI_PERF_ENABLED = True
-    except Exception:
-        raise
+    except Exception as e:
         print("User requested Caliper annotations, but could not import Caliper")
+        print(f"Exception: {e}")
 elif (
     TORCH_PERF_ENV_VAR in os.environ
     and os.environ.get(TORCH_PERF_ENV_VAR).lower() != "off"
diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index e1706b2..691a774 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -595,6 +595,7 @@ def train(self):
                 #
                 begin_code_region("checkpoint")
 
+                # Checkpoint only if at a checkpoint_interval epoch
                 if epoch % self.config.checkpoint_interval == 0:
                     extras = {"train_mask_values": self.train_set.mask_values}
                     self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras)
@@ -605,8 +606,3 @@ def train(self):
                 epoch += 1
 
         adiak_value("final_epochs", epoch)
-        # if val_score >= 0.95:
-        #     self.log.info(
-        #         f"val_score of {val_score} is > threshold of 0.95. Benchmark run complete. Wrapping up..."
-        #     )
-        #     return 0
diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py
index ff4ac8a..431c2eb 100644
--- a/ScaFFold/worker.py
+++ b/ScaFFold/worker.py
@@ -246,13 +246,6 @@ def main(kwargs_dict: dict = {}):
         Trained to >= 0.95 validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs."
     )
 
-    # solve hang?
-    if os.getenv("SKIP_DIST_BARRIERS") != "1":
-        torch.cuda.synchronize()
-        print(f"Done cuda sync rank {rank}")
-        dist.barrier()
-        print(f"Done barrier rank {rank}")
-
     #
     # Generate plots
     #
@@ -262,9 +255,7 @@ def main(kwargs_dict: dict = {}):
         standard_viz.main(config)
         end_code_region("generate_figures")
 
-    if os.getenv("SKIP_DIST_BARRIERS") != "1":
-        dist.barrier()
-        print(f"Done barrier rank {rank}")
-        dist.destroy_process_group()
+    dist.barrier()
+    dist.destroy_process_group()
 
     return 0

From 07c295c757f69eefafcb2740352938356f2b33e4 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Thu, 5 Feb 2026 18:39:22 -0800
Subject: [PATCH 7/7] lint

---
 ScaFFold/utils/trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index 691a774..108f13e 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -341,7 +341,6 @@ def train(self):
         end_code_region("warmup")
         self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s")
 
-
         epoch = 1
         dice_score_train = 0
         with open(self.outfile_path, "a", newline="") as outfile: