diff --git a/pie/default_settings.json b/pie/default_settings.json
index cc5279f..da1a848 100644
--- a/pie/default_settings.json
+++ b/pie/default_settings.json
@@ -124,10 +124,17 @@
   "optimizer": "Adam", // optimizer type
   "clip_norm": 5.0, // clip norm of gradient up to this value
   "lr": 0.001,
-  "lr_factor": 0.75, // lr schedule (decrease lr by this factor after `lr_patience` epochs
-                     // without improvement on dev-set data)
-  "min_lr": 0.000001, // minimum learning rate
-  "lr_patience": 2, // patience for lr schedule
+
+  "lr_scheduler": "ReduceLROnPlateau",
+  "lr_scheduler_delay": 0, // Number of steps without using the lr_scheduler
+  "lr_scheduler_params": {
+    // needs to be adapted if lr_scheduler is not ReduceLROnPlateau
+    "mode": "max",
+    "factor": 0.75,
+    "patience": 2,
+    "min_lr": 0.000001
+  },
+
   "checks_per_epoch": 1, // check model on dev-set so many times during epoch

   // * Model hyperparameters
diff --git a/pie/settings.py b/pie/settings.py
index d2a266a..969d584 100644
--- a/pie/settings.py
+++ b/pie/settings.py
@@ -106,6 +106,12 @@ def check_settings(settings):
     if not has_target:
         raise ValueError("Needs at least one target task")

+    # backward compatibility
+    if "lr_patience" in settings:
+        settings.lr_scheduler_params['patience'] = settings.lr_patience
+    if "lr_factor" in settings:
+        settings.lr_scheduler_params['factor'] = settings.lr_factor
+
     return settings
diff --git a/pie/trainer.py b/pie/trainer.py
index 2dceb7e..4dffcf3 100644
--- a/pie/trainer.py
+++ b/pie/trainer.py
@@ -1,4 +1,5 @@
+import inspect
 import os
 import uuid
 import logging
@@ -7,7 +8,6 @@
 import random
 import tempfile

-
 import tqdm

 import torch
@@ -150,19 +150,43 @@ def get_weights(self):


 class LRScheduler(object):
-    def __init__(self, optimizer, threshold=0.0, **kwargs):
-        self.lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer, mode='max', threshold=threshold, **kwargs)
+    def __init__(self, optimizer,
+                 lr_scheduler='ReduceLROnPlateau',
+                 delay=0, **kwargs):
+
+        self.nb_steps: int = 0
+        self.delay: int = delay
+        lr_scheduler_cls = getattr(optim.lr_scheduler, lr_scheduler)
+        params = inspect.signature(lr_scheduler_cls).parameters
+        self.lr_scheduler = lr_scheduler_cls(
+            optimizer,
+            # pick kwargs that fit the selected scheduler
+            **{kwarg: val for kwarg, val in kwargs.items() if kwarg in params})

     def step(self, score):
-        self.lr_scheduler.step(score)
+        self.nb_steps += 1
+
+        # only apply the lr_scheduler's step() once the delay has been reached
+        if self.delay and self.nb_steps <= self.delay:
+            return
+
+        if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+            self.lr_scheduler.step(score)
+        else:
+            self.lr_scheduler.step()

     def __repr__(self):
-        return '<LrScheduler lr="{}" num_bad_epochs="{}" patience="{}" threshold="{}">' \
-            .format(self.lr_scheduler.optimizer.param_groups[0]['lr'],
-                    self.lr_scheduler.num_bad_epochs,
-                    self.lr_scheduler.patience,
-                    self.lr_scheduler.threshold)
+        res = '
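
A minimal usage sketch of the new wrapper, assuming the patched `pie.trainer.LRScheduler` above; the model, the dev score, and the `StepLR` parameters are placeholders, while the `ReduceLROnPlateau` keyword arguments mirror the new `lr_scheduler_params` defaults in `default_settings.json`:

    import torch
    from torch import optim

    from pie.trainer import LRScheduler

    model = torch.nn.Linear(10, 2)  # stand-in model
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # default behaviour: ReduceLROnPlateau driven by the dev score ("max" mode)
    scheduler = LRScheduler(
        optimizer, lr_scheduler='ReduceLROnPlateau', delay=0,
        mode='max', factor=0.75, patience=2, min_lr=0.000001)

    # any other scheduler in torch.optim.lr_scheduler can be named instead;
    # kwargs it does not accept (here 'mode') are dropped by the
    # inspect.signature filter, and step() then ignores the score
    scheduler = LRScheduler(
        optimizer, lr_scheduler='StepLR', delay=10,
        step_size=5, gamma=0.9, mode='max')

    for epoch in range(20):
        dev_score = 0.5  # placeholder for the dev-set metric
        scheduler.step(dev_score)  # no-op for the first `delay` calls

Because the kwargs are filtered against `inspect.signature`, the same `lr_scheduler_params` dict can be passed whichever scheduler is named; keys the selected class does not accept are silently dropped rather than raising a `TypeError`.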