From a8ad0da69abb98dbefc1beea22d013ecf976be73 Mon Sep 17 00:00:00 2001 From: Enrique Manjavacas Date: Mon, 7 Dec 2020 14:22:01 +0100 Subject: [PATCH 1/4] adapt --- pie/default_settings.json | 14 ++++++++++---- pie/settings.py | 6 ++++++ pie/trainer.py | 27 ++++++++++++++++----------- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/pie/default_settings.json b/pie/default_settings.json index cc5279f..68d3e2a 100644 --- a/pie/default_settings.json +++ b/pie/default_settings.json @@ -124,10 +124,16 @@ "optimizer": "Adam", // optimizer type "clip_norm": 5.0, // clip norm of gradient up to this value "lr": 0.001, - "lr_factor": 0.75, // lr schedule (decrease lr by this factor after `lr_patience` epochs - // without improvement on dev-set data) - "min_lr": 0.000001, // minimum learning rate - "lr_patience": 2, // patience for lr schedule + + "lr_scheduler": "ReduceLROnPlateau", + "lr_scheduler_params": { + // needs to be adapted if lr_scheduler is not ReduceLROnPlateau + "mode": "max", + "factor": 0.75, + "patience": 2, + "min_lr": 0.000001 + }, + "checks_per_epoch": 1, // check model on dev-set so many times during epoch // * Model hyperparameters diff --git a/pie/settings.py b/pie/settings.py index d2a266a..969d584 100644 --- a/pie/settings.py +++ b/pie/settings.py @@ -106,6 +106,12 @@ def check_settings(settings): if not has_target: raise ValueError("Needs at least one target task") + # backward compatibility + if "lr_patience" in settings: + settings.lr_scheduler_params['patience'] = settings.lr_patience + if "lr_factor" in settings: + settings.lr_scheduler_params['factor'] = settings.lr_factor + return settings diff --git a/pie/trainer.py b/pie/trainer.py index 2dceb7e..cf8e122 100644 --- a/pie/trainer.py +++ b/pie/trainer.py @@ -150,19 +150,25 @@ def get_weights(self): class LRScheduler(object): - def __init__(self, optimizer, threshold=0.0, **kwargs): - self.lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau( - optimizer, mode='max', threshold=threshold, **kwargs) + def __init__(self, optimizer, lr_scheduler='ReduceLROnPlateau', **kwargs) + self.lr_scheduler = getattr(optim.lr_scheduler, lr_scheduler)( + optimizer, **kwargs) def step(self, score): - self.lr_scheduler.step(score) + if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(score) + else: + slef.lr_scheduler.step() def __repr__(self): - return '' \ - .format(self.lr_scheduler.optimizer.param_groups[0]['lr'], - self.lr_scheduler.num_bad_epochs, - self.lr_scheduler.patience, - self.lr_scheduler.threshold) + res = ' Date: Mon, 7 Dec 2020 14:56:15 +0100 Subject: [PATCH 2/4] add lrscheduler name to string --- pie/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pie/trainer.py b/pie/trainer.py index cf8e122..f1ab4e0 100644 --- a/pie/trainer.py +++ b/pie/trainer.py @@ -161,7 +161,8 @@ def step(self, score): slef.lr_scheduler.step() def __repr__(self): - res = ' Date: Mon, 7 Dec 2020 17:15:52 +0100 Subject: [PATCH 3/4] Feature/allow picking lr scheduler (#80) * Fix code typo in trainer.py * (Trainer) Implement Delay on LRScheduler --- pie/default_settings.json | 1 + pie/trainer.py | 24 +++++++++++++++++++----- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/pie/default_settings.json b/pie/default_settings.json index 68d3e2a..da1a848 100644 --- a/pie/default_settings.json +++ b/pie/default_settings.json @@ -126,6 +126,7 @@ "lr": 0.001, "lr_scheduler": "ReduceLROnPlateau", + "lr_scheduler_delay": 0, // Number of steps without using the lr_scheduler "lr_scheduler_params": { // needs to be adapted if lr_scheduler is not ReduceLROnPlateau "mode": "max", diff --git a/pie/trainer.py b/pie/trainer.py index f1ab4e0..0921109 100644 --- a/pie/trainer.py +++ b/pie/trainer.py @@ -150,20 +150,30 @@ def get_weights(self): class LRScheduler(object): - def __init__(self, optimizer, lr_scheduler='ReduceLROnPlateau', **kwargs) + def __init__(self, optimizer, lr_scheduler='ReduceLROnPlateau', delay=0, **kwargs): + self.nb_steps: int = 0 + self.delay: int = delay self.lr_scheduler = getattr(optim.lr_scheduler, lr_scheduler)( optimizer, **kwargs) def step(self, score): + self.nb_steps += 1 + + # If we have a delay, we do not apply the step() method of the lr_scheduler until it's reached + if self.delay and self.nb_steps <= self.delay: + return + if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau): self.lr_scheduler.step(score) else: - slef.lr_scheduler.step() + self.lr_scheduler.step() def __repr__(self): - res = ' Date: Mon, 7 Dec 2020 17:20:44 +0100 Subject: [PATCH 4/4] fix --- pie/trainer.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/pie/trainer.py b/pie/trainer.py index 0921109..4dffcf3 100644 --- a/pie/trainer.py +++ b/pie/trainer.py @@ -1,4 +1,5 @@ +import inspect import os import uuid import logging @@ -7,7 +8,6 @@ import random import tempfile - import tqdm import torch @@ -150,16 +150,23 @@ def get_weights(self): class LRScheduler(object): - def __init__(self, optimizer, lr_scheduler='ReduceLROnPlateau', delay=0, **kwargs): + def __init__(self, optimizer, + lr_scheduler='ReduceLROnPlateau', + delay=0, **kwargs): + self.nb_steps: int = 0 self.delay: int = delay - self.lr_scheduler = getattr(optim.lr_scheduler, lr_scheduler)( - optimizer, **kwargs) + lr_scheduler_cls = getattr(optim.lr_scheduler, lr_scheduler) + params = inspect.signature(lr_scheduler_cls).parameters + self.lr_scheduler = lr_scheduler_cls( + optimizer, + # pick kwargs that fit the selected scheduler + **{kwarg: val for kwarg, val in kwargs.items() if kwarg in params}) def step(self, score): self.nb_steps += 1 - # If we have a delay, we do not apply the step() method of the lr_scheduler until it's reached + # apply the step() method of the lr_scheduler when delay is reached if self.delay and self.nb_steps <= self.delay: return @@ -220,8 +227,7 @@ def __init__(self, settings, model, dataset, num_instances): self.optimizer, lr_scheduler=settings.lr_scheduler, delay=settings.lr_scheduler_delay, - **settings.lr_scheduler_params - ) + **settings.lr_scheduler_params) if settings.verbose: print()