Skip to content
This repository was archived by the owner on May 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions foundations/hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class TrainingHparams(Hparams):
warmup_steps: str = None
weight_decay: float = None
apex_fp16: bool = False
save_every_n_epochs: int = None

_name: str = 'Training Hyperparameters'
_description: str = 'Hyperparameters that determine how the model is trained.'
Expand All @@ -175,6 +176,7 @@ class TrainingHparams(Hparams):
_warmup_steps: str = "Steps of linear lr warmup at the start of training as epochs ('20ep') or iterations ('800it')"
_weight_decay: str = 'The L2 penalty to apply to the weights.'
_apex_fp16: bool = 'Whether to train the model in float16 using the NVIDIA Apex library.'
_save_every_n_epochs: int = 'Save weights epN_it0 every time epoch N is divisible by this value (default None)'


@dataclass
Expand Down
6 changes: 4 additions & 2 deletions training/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class TrainingRunner(Runner):
desc: TrainingDesc
verbose: bool = True
evaluate_every_epoch: bool = True
save_every_n_epochs: int = None

@staticmethod
def description():
Expand All @@ -33,7 +34,7 @@ def add_args(parser: argparse.ArgumentParser) -> None:
@staticmethod
def create_from_args(args: argparse.Namespace) -> 'TrainingRunner':
return TrainingRunner(args.replicate, TrainingDesc.create_from_args(args),
not args.quiet, not args.evaluate_only_at_end)
not args.quiet, not args.evaluate_only_at_end, args.save_every_n_epochs)

def display_output_location(self):
print(self.desc.run_path(self.replicate))
Expand All @@ -46,4 +47,5 @@ def run(self):
self.desc.save(self.desc.run_path(self.replicate))
train.standard_train(
models.registry.get(self.desc.model_hparams), self.desc.run_path(self.replicate),
self.desc.dataset_hparams, self.desc.training_hparams, evaluate_every_epoch=self.evaluate_every_epoch)
self.desc.dataset_hparams, self.desc.training_hparams, evaluate_every_epoch=self.evaluate_every_epoch,
save_every_n_epochs=self.save_every_n_epochs)
14 changes: 13 additions & 1 deletion training/standard_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,19 @@ def modified_callback(output_location, step, model, optimizer, logger):
return modified_callback


def run_every_n_epochs(callback, n, offset=0):
    """Wrap `callback` so it only fires at iteration 0 of epochs on an n-epoch schedule.

    The wrapped callback runs exactly when `step.it == 0` and
    `(step.ep - offset) % n == 0`; on every other step it is a no-op.
    Note this includes epoch `offset` itself (and epoch 0 when offset is 0).
    """
    def wrapped(output_location, step, model, optimizer, logger):
        at_epoch_boundary = step.it == 0
        on_schedule = (step.ep - offset) % n == 0
        if at_epoch_boundary and on_schedule:
            callback(output_location, step, model, optimizer, logger)
    return wrapped


# The standard set of callbacks that should be used for a normal training run.
def standard_callbacks(training_hparams: hparams.TrainingHparams, train_set_loader: DataLoader,
test_set_loader: DataLoader, eval_on_train: bool = False, verbose: bool = True,
start_step: Step = None, evaluate_every_epoch: bool = True):
start_step: Step = None, evaluate_every_epoch: bool = True,
save_every_n_epochs: int = None):
start = start_step or Step.zero(train_set_loader.iterations_per_epoch)
end = Step.from_str(training_hparams.training_steps, train_set_loader.iterations_per_epoch)
test_eval_callback = create_eval_callback('test', test_set_loader, verbose=verbose)
Expand All @@ -131,6 +140,9 @@ def standard_callbacks(training_hparams: hparams.TrainingHparams, train_set_load
if evaluate_every_epoch: result = [run_every_epoch(test_eval_callback)] + result
elif verbose: result.append(run_every_epoch(create_timekeeper_callback()))

# Save model weights every N epochs if requested
if save_every_n_epochs is not None: result.append(run_every_n_epochs(save_model, n=save_every_n_epochs, offset=0))

# Ensure that testing occurs at least at the beginning and end of training.
if start.it != 0 or not evaluate_every_epoch: result = [run_at_step(start, test_eval_callback)] + result
if end.it != 0 or not evaluate_every_epoch: result = [run_at_step(end, test_eval_callback)] + result
Expand Down
6 changes: 4 additions & 2 deletions training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def standard_train(
training_hparams: hparams.TrainingHparams,
start_step: Step = None,
verbose: bool = True,
evaluate_every_epoch: bool = True
evaluate_every_epoch: bool = True,
save_every_n_epochs=None,
):
"""Train using the standard callbacks according to the provided hparams."""

Expand All @@ -152,5 +153,6 @@ def standard_train(
test_loader = datasets.registry.get(dataset_hparams, train=False)
callbacks = standard_callbacks.standard_callbacks(
training_hparams, train_loader, test_loader, start_step=start_step,
verbose=verbose, evaluate_every_epoch=evaluate_every_epoch)
verbose=verbose, evaluate_every_epoch=evaluate_every_epoch,
save_every_n_epochs=save_every_n_epochs)
train(training_hparams, model, train_loader, output_location, callbacks, start_step=start_step)