6 changes: 5 additions & 1 deletion ScaFFold/cli.py
@@ -78,6 +78,11 @@ def main():
type=int,
help="Determines dataset resolution and number of UNet layers.",
)
generate_fractals_parser.add_argument(
"--n-categories",
type=int,
help="Number of fractal categories present in the dataset.",
)
generate_fractals_parser.add_argument(
"--fract-base-dir",
type=str,
@@ -114,7 +119,6 @@ def main():
benchmark_parser.add_argument(
"--n-categories",
type=int,
nargs="+",
help="Number of fractal categories present in the dataset.",
)
benchmark_parser.add_argument(
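With nargs="+" removed, the benchmark's --n-categories flag now parses to a single int rather than a list of ints. A minimal standalone argparse sketch of the difference (hypothetical parser, not ScaFFold's actual CLI wiring):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--n-categories",
    type=int,
    help="Number of fractal categories present in the dataset.",
)

args = parser.parse_args(["--n-categories", "4"])
print(args.n_categories)  # 4 (an int; with nargs="+" this would have been [4])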
7 changes: 4 additions & 3 deletions ScaFFold/configs/benchmark_default.yml
@@ -9,15 +9,15 @@ unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dim
seed: 42 # Random seed.
batch_size: 1 # Batch sizes for each vol size.
optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp.
num_shards: 2 # DistConv param: number of shards to divide the tensor into
num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum.
shard_dim: 2 # DistConv param: dimension on which to shard
checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems.

# Internal/dev use only
variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15.
n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
val_split: 25 # In percent.
epochs: 2000 # Number of training epochs.
epochs: -1 # Maximum number of training epochs; -1 removes the cap so training runs until target_dice is reached.
learning_rate: .0001 # Learning rate for training.
disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
more_determinism: 0 # If 1, improve model training determinism.
@@ -30,4 +30,5 @@ checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpo
loss_freq: 1 # Number of epochs between logging the overall loss.
normalize: 1 # Category search normalization parameter
warmup_epochs: 1 # How many warmup epochs before training
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
target_dice: 0.95 # Stop training once the training Dice score reaches this value.
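A quick sketch of how the two changed settings read once this file is loaded, assuming the config is parsed with PyYAML (the path comes from this diff; the loading snippet itself is illustrative):

import yaml

with open("ScaFFold/configs/benchmark_default.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["epochs"])       # -1   -> no fixed epoch cap
print(cfg["target_dice"])  # 0.95 -> train until the training Dice reaches this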
1 change: 1 addition & 0 deletions ScaFFold/utils/config_utils.py
@@ -72,6 +72,7 @@ def __init__(self, config_dict):
self.dataset_reuse_enforce_commit_id = config_dict[
"dataset_reuse_enforce_commit_id"
]
self.target_dice = config_dict["target_dice"]
self.checkpoint_interval = config_dict["checkpoint_interval"]


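Since the new lookup indexes config_dict["target_dice"] directly, older config files that lack the key will raise KeyError during Config construction. A small illustrative sketch of that behavior next to a tolerant dict.get() variant, should backward compatibility be desired (the 0.95 fallback here is only an assumption):

config_dict = {"checkpoint_interval": 10}  # note: no "target_dice" key

# As added in this diff: a missing key raises KeyError.
try:
    target_dice = config_dict["target_dice"]
except KeyError:
    print("target_dice missing from config")

# A defensive alternative with an assumed default:
target_dice = config_dict.get("target_dice", 0.95)
print(target_dice)  # 0.95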
3 changes: 2 additions & 1 deletion ScaFFold/utils/perf_measure.py
@@ -27,8 +27,9 @@
from pycaliper.instrumentation import begin_region, end_region

_CALI_PERF_ENABLED = True
except Exception:
except Exception as e:
print("User requested Caliper annotations, but could not import Caliper")
print(f"Exception: {e}")
elif (
TORCH_PERF_ENV_VAR in os.environ
and os.environ.get(TORCH_PERF_ENV_VAR).lower() != "off"
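A self-contained sketch of the optional-import pattern this hunk refines; binding the exception (except Exception as e) surfaces the reason the import failed rather than only the fact that it failed:

_CALI_PERF_ENABLED = False

try:
    from pycaliper.instrumentation import begin_region, end_region  # noqa: F401

    _CALI_PERF_ENABLED = True
except Exception as e:
    print("User requested Caliper annotations, but could not import Caliper")
    print(f"Exception: {e}")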
29 changes: 21 additions & 8 deletions ScaFFold/utils/trainer.py
@@ -40,7 +40,7 @@

# Local
from ScaFFold.utils.evaluate import evaluate
from ScaFFold.utils.perf_measure import begin_code_region, end_code_region
from ScaFFold.utils.perf_measure import adiak_value, begin_code_region, end_code_region
from ScaFFold.utils.utils import gather_and_print_mem


@@ -404,9 +404,17 @@ def train(self):
end_code_region("warmup")
self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s")

epoch = 1
dice_score_train = 0
with open(self.outfile_path, "a", newline="") as outfile:
start = time.time()
for epoch in range(self.start_epoch, self.config.epochs + 1):
while dice_score_train < self.config.target_dice:
if self.config.epochs != -1 and epoch > self.config.epochs:
print(
f"Maxmimum epochs reached '{self.config.epochs}'. Concluding training early (may have not converged)."
)
break

# DistConv ParallelStrategy
ps = getattr(self.config, "_parallel_strategy", None)
if ps is None:
@@ -427,10 +435,15 @@
self.val_loader.sampler.set_epoch(epoch)
self.model.train()

estr = (
f"{epoch}"
if self.config.epochs == -1
else f"{epoch}/{self.config.epochs}"
)
with tqdm(
total=self.n_train // self.world_size,
desc=f"({os.path.basename(self.config.run_dir)}) \
Epoch {epoch}/{self.config.epochs}",
Epoch {estr}",
unit="img",
disable=True if self.world_rank != 0 else False,
) as pbar:
@@ -644,14 +657,14 @@ def train(self):
#
begin_code_region("checkpoint")

# Checkpoint only at checkpoint_interval epochs
if epoch % self.config.checkpoint_interval == 0:
extras = {"train_mask_values": self.train_set.mask_values}
self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras)

end_code_region("checkpoint")

if val_score >= 0.95:
self.log.info(
f"val_score of {val_score} is > threshold of 0.95. Benchmark run complete. Wrapping up..."
)
return 0
dice_score_train = dice_sum
epoch += 1

adiak_value("final_epochs", epoch)
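A simplified, self-contained skeleton of the new control flow: train until the training Dice reaches target_dice, with an optional epoch cap when epochs is not -1. train_one_epoch is a hypothetical stand-in; the real loop also handles samplers, validation, checkpointing, and logging:

import random

def train_one_epoch():
    # Stand-in for one epoch of training; returns a fake training Dice score.
    return random.uniform(0.90, 1.0)

target_dice = 0.95
max_epochs = -1  # -1 means no cap, matching epochs: -1 in the config

epoch = 1
dice_score_train = 0.0
while dice_score_train < target_dice:
    if max_epochs != -1 and epoch > max_epochs:
        print(f"Maximum epochs reached '{max_epochs}'. "
              "Concluding training early (may not have converged).")
        break
    dice_score_train = train_one_epoch()
    epoch += 1

print(f"final_epochs: {epoch}")  # mirrors adiak_value("final_epochs", epoch)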