diff --git a/foundations/hparams.py b/foundations/hparams.py index e513c533..6014cc95 100644 --- a/foundations/hparams.py +++ b/foundations/hparams.py @@ -160,6 +160,7 @@ class TrainingHparams(Hparams): warmup_steps: str = None weight_decay: float = None apex_fp16: bool = False + save_every_n_epochs: int = None _name: str = 'Training Hyperparameters' _description: str = 'Hyperparameters that determine how the model is trained.' @@ -175,6 +176,7 @@ class TrainingHparams(Hparams): _warmup_steps: str = "Steps of linear lr warmup at the start of training as epochs ('20ep') or iterations ('800it')" _weight_decay: str = 'The L2 penalty to apply to the weights.' _apex_fp16: bool = 'Whether to train the model in float16 using the NVIDIA Apex library.' + _save_every_n_epochs: str = 'Save weights epN_it0 every time epoch N is divisible by this value (default None)' @dataclass diff --git a/training/runner.py b/training/runner.py index 7698f747..5ee69456 100644 --- a/training/runner.py +++ b/training/runner.py @@ -20,6 +20,7 @@ class TrainingRunner(Runner): desc: TrainingDesc verbose: bool = True evaluate_every_epoch: bool = True + save_every_n_epochs: int = None @staticmethod def description(): @@ -33,7 +34,7 @@ def add_args(parser: argparse.ArgumentParser) -> None: @staticmethod def create_from_args(args: argparse.Namespace) -> 'TrainingRunner': return TrainingRunner(args.replicate, TrainingDesc.create_from_args(args), - not args.quiet, not args.evaluate_only_at_end) + not args.quiet, not args.evaluate_only_at_end, args.save_every_n_epochs) def display_output_location(self): print(self.desc.run_path(self.replicate)) @@ -46,4 +47,5 @@ def run(self): self.desc.save(self.desc.run_path(self.replicate)) train.standard_train( models.registry.get(self.desc.model_hparams), self.desc.run_path(self.replicate), - self.desc.dataset_hparams, self.desc.training_hparams, evaluate_every_epoch=self.evaluate_every_epoch) + self.desc.dataset_hparams, self.desc.training_hparams, 
evaluate_every_epoch=self.evaluate_every_epoch, + save_every_n_epochs=self.save_every_n_epochs) diff --git a/training/standard_callbacks.py b/training/standard_callbacks.py index 3657d6b5..4633d37e 100644 --- a/training/standard_callbacks.py +++ b/training/standard_callbacks.py @@ -110,10 +110,19 @@ def modified_callback(output_location, step, model, optimizer, logger): return modified_callback +def run_every_n_epochs(callback, n, offset=0): + def modified_callback(output_location, step, model, optimizer, logger): + if step.it != 0 or (step.ep - offset) % n != 0: + return + callback(output_location, step, model, optimizer, logger) + return modified_callback + + # The standard set of callbacks that should be used for a normal training run. def standard_callbacks(training_hparams: hparams.TrainingHparams, train_set_loader: DataLoader, test_set_loader: DataLoader, eval_on_train: bool = False, verbose: bool = True, - start_step: Step = None, evaluate_every_epoch: bool = True): + start_step: Step = None, evaluate_every_epoch: bool = True, + save_every_n_epochs: int = None): start = start_step or Step.zero(train_set_loader.iterations_per_epoch) end = Step.from_str(training_hparams.training_steps, train_set_loader.iterations_per_epoch) test_eval_callback = create_eval_callback('test', test_set_loader, verbose=verbose) @@ -131,6 +140,9 @@ def standard_callbacks(training_hparams: hparams.TrainingHparams, train_set_load if evaluate_every_epoch: result = [run_every_epoch(test_eval_callback)] + result elif verbose: result.append(run_every_epoch(create_timekeeper_callback())) + # Save model weights every N epochs if requested + if save_every_n_epochs is not None: result.append(run_every_n_epochs(save_model, n=save_every_n_epochs, offset=0)) + # Ensure that testing occurs at least at the beginning and end of training. 
if start.it != 0 or not evaluate_every_epoch: result = [run_at_step(start, test_eval_callback)] + result if end.it != 0 or not evaluate_every_epoch: result = [run_at_step(end, test_eval_callback)] + result diff --git a/training/train.py b/training/train.py index 99951b46..5b92ffda 100644 --- a/training/train.py +++ b/training/train.py @@ -138,7 +138,8 @@ def standard_train( training_hparams: hparams.TrainingHparams, start_step: Step = None, verbose: bool = True, - evaluate_every_epoch: bool = True + evaluate_every_epoch: bool = True, + save_every_n_epochs: int = None, ): """Train using the standard callbacks according to the provided hparams.""" @@ -152,5 +153,6 @@ def standard_train( test_loader = datasets.registry.get(dataset_hparams, train=False) callbacks = standard_callbacks.standard_callbacks( training_hparams, train_loader, test_loader, start_step=start_step, - verbose=verbose, evaluate_every_epoch=evaluate_every_epoch) + verbose=verbose, evaluate_every_epoch=evaluate_every_epoch, + save_every_n_epochs=save_every_n_epochs) train(training_hparams, model, train_loader, output_location, callbacks, start_step=start_step)