diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py
index 82532fe..a360c38 100644
--- a/ScaFFold/cli.py
+++ b/ScaFFold/cli.py
@@ -78,6 +78,11 @@ def main():
         type=int,
         help="Determines dataset resolution and number of UNet layers.",
     )
+    generate_fractals_parser.add_argument(
+        "--n-categories",
+        type=int,
+        help="Number of fractal categories present in the dataset.",
+    )
     generate_fractals_parser.add_argument(
         "--fract-base-dir",
         type=str,
@@ -114,7 +119,6 @@ def main():
     benchmark_parser.add_argument(
         "--n-categories",
         type=int,
-        nargs="+",
         help="Number of fractal categories present in the dataset.",
     )
     benchmark_parser.add_argument(
diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml
index f6d6a17..fce1042 100644
--- a/ScaFFold/configs/benchmark_default.yml
+++ b/ScaFFold/configs/benchmark_default.yml
@@ -9,7 +9,7 @@ unet_bottleneck_dim: 3                 # Power of 2 of the unet bottleneck layer dim
 seed: 42                               # Random seed.
 batch_size: 1                          # Batch sizes for each vol size.
 optimizer: "ADAM"                      # "ADAM" is preferred option, otherwise training defautls to RMSProp.
-num_shards: 2                          # DistConv param: number of shards to divide the tensor into
+num_shards: 2                          # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
 shard_dim: 2                           # DistConv param: dimension on which to shard
 
 checkpoint_interval: 10                # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems.
@@ -17,7 +17,7 @@ checkpoint_interval: 10                # Checkpoint every C epochs. More frequent ch
 variance_threshold: 0.15               # Variance threshold for valid fractals. Default is 0.15.
 n_fracts_per_vol: 3                    # Number of fractals overlaid in each volume. Default is 3.
 val_split: 25                          # In percent.
-epochs: 2000                           # Number of training epochs.
+epochs: -1                             # Number of training epochs.
 learning_rate: .0001                   # Learning rate for training.
 disable_scheduler: 1                   # If 1, disable scheduler during training to use constant LR.
 more_determinism: 0                    # If 1, improve model training determinism.
@@ -30,4 +30,5 @@ checkpoint_dir: "checkpoints"          # Subfolder in which to save training checkpo
 loss_freq: 1                           # Number of epochs between logging the overall loss.
 normalize: 1                           # Cateogry search normalization parameter
 warmup_epochs: 1                       # How many warmup epochs before training
-dataset_reuse_enforce_commit_id: 0     # Enforce matching commit IDs for dataset reuse.
\ No newline at end of file
+dataset_reuse_enforce_commit_id: 0     # Enforce matching commit IDs for dataset reuse.
+target_dice: 0.95
\ No newline at end of file
diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py
index 1f3e3a6..08cb481 100644
--- a/ScaFFold/utils/config_utils.py
+++ b/ScaFFold/utils/config_utils.py
@@ -72,6 +72,7 @@ def __init__(self, config_dict):
         self.dataset_reuse_enforce_commit_id = config_dict[
             "dataset_reuse_enforce_commit_id"
         ]
+        self.target_dice = config_dict["target_dice"]
 
         self.checkpoint_interval = config_dict["checkpoint_interval"]
 
diff --git a/ScaFFold/utils/perf_measure.py b/ScaFFold/utils/perf_measure.py
index e0ac1a2..5af8d5b 100644
--- a/ScaFFold/utils/perf_measure.py
+++ b/ScaFFold/utils/perf_measure.py
@@ -27,8 +27,9 @@
         from pycaliper.instrumentation import begin_region, end_region
 
         _CALI_PERF_ENABLED = True
-    except Exception:
+    except Exception as e:
         print("User requested Caliper annotations, but could not import Caliper")
+        print(f"Exception: {e}")
 elif (
     TORCH_PERF_ENV_VAR in os.environ
     and os.environ.get(TORCH_PERF_ENV_VAR).lower() != "off"
diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py
index ac92908..d74be08 100644
--- a/ScaFFold/utils/trainer.py
+++ b/ScaFFold/utils/trainer.py
@@ -40,7 +40,7 @@
 
 # Local
 from ScaFFold.utils.evaluate import evaluate
-from ScaFFold.utils.perf_measure import begin_code_region, end_code_region
+from ScaFFold.utils.perf_measure import adiak_value, begin_code_region, end_code_region
 from ScaFFold.utils.utils import gather_and_print_mem
 
 
@@ -404,9 +404,17 @@ def train(self):
         end_code_region("warmup")
         self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s")
 
+        epoch = 1
+        dice_score_train = 0
         with open(self.outfile_path, "a", newline="") as outfile:
             start = time.time()
-            for epoch in range(self.start_epoch, self.config.epochs + 1):
+            while dice_score_train < self.config.target_dice:
+                if self.config.epochs != -1 and epoch > self.config.epochs:
+                    print(
+                        f"Maximum epochs reached '{self.config.epochs}'. Concluding training early (may not have converged)."
+                    )
+                    break
+
                 # DistConv ParallelStrategy
                 ps = getattr(self.config, "_parallel_strategy", None)
                 if ps is None:
@@ -427,10 +435,15 @@
                 self.val_loader.sampler.set_epoch(epoch)
 
                 self.model.train()
+                estr = (
+                    f"{epoch}"
+                    if self.config.epochs == -1
+                    else f"{epoch}/{self.config.epochs}"
+                )
                 with tqdm(
                     total=self.n_train // self.world_size,
                     desc=f"({os.path.basename(self.config.run_dir)}) \
-                        Epoch {epoch}/{self.config.epochs}",
+                        Epoch {estr}",
                     unit="img",
                     disable=True if self.world_rank != 0 else False,
                 ) as pbar:
@@ -644,14 +657,14 @@ def train(self):
 
             # begin_code_region("checkpoint")
+            # Checkpoint only if at a checkpoint_interval epoch
             if epoch % self.config.checkpoint_interval == 0:
                 extras = {"train_mask_values": self.train_set.mask_values}
                 self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras)
             end_code_region("checkpoint")
 
-            if val_score >= 0.95:
-                self.log.info(
-                    f"val_score of {val_score} is > threshold of 0.95. Benchmark run complete. Wrapping up..."
-                )
-                return 0
+            dice_score_train = dice_sum
+            epoch += 1
+
+        adiak_value("final_epochs", epoch)
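
For reference, the snippet below is a minimal, standalone sketch of the convergence-based stopping rule this patch introduces in Trainer.train(): training now loops until the training Dice score reaches target_dice (0.95 in benchmark_default.yml), and epochs: -1 removes the epoch cap, while any other value ends the run early once the cap is exceeded. The names run_one_epoch and train_until_target are hypothetical stand-ins used only for illustration; the real per-epoch body (DistConv setup, tqdm progress, checkpointing) lives in the trainer.

# Minimal, standalone sketch of the convergence-based stopping rule added to
# Trainer.train(). Illustration only: run_one_epoch is a hypothetical stand-in
# for the real per-epoch training/validation body.


def run_one_epoch(epoch: int) -> float:
    """Pretend epoch body: returns a slowly improving training Dice score."""
    return min(1.0, 0.80 + 0.005 * epoch)


def train_until_target(target_dice: float, epochs: int = -1) -> int:
    """Train until the training Dice reaches target_dice.

    epochs == -1 means no cap (as in the new benchmark_default.yml); any other
    value stops training early once the cap is exceeded, converged or not.
    """
    epoch = 1
    dice_score_train = 0.0
    while dice_score_train < target_dice:
        if epochs != -1 and epoch > epochs:
            print(f"Maximum epochs reached '{epochs}'. Concluding training early (may not have converged).")
            break
        dice_score_train = run_one_epoch(epoch)
        epoch += 1
    # The real trainer records this count via adiak_value("final_epochs", epoch).
    return epoch


if __name__ == "__main__":
    print("final_epochs:", train_until_target(target_dice=0.95, epochs=2000))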