From bf110dece7bfc1e03cbd69cd85ffc4e330e33d87 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Wed, 30 Apr 2025 16:08:46 -0700
Subject: [PATCH] Copy gradient modifier usage from Adversary to LitModular for
 training universal perturbations.

---
 mart/models/modular.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mart/models/modular.py b/mart/models/modular.py
index afab122e..0c8ebb2c 100644
--- a/mart/models/modular.py
+++ b/mart/models/modular.py
@@ -25,6 +25,7 @@ def __init__(
         self,
         modules,
         optimizer,
+        gradient_modifier=None,
         lr_scheduler=None,
         training_sequence=None,
         training_step_log=None,
@@ -71,6 +72,7 @@ def __init__(
         # Set bias_decay and norm_decay to 0.
         self.optimizer_fn = OptimizerFactory(self.optimizer_fn)
 
+        self.gradient_modifier = gradient_modifier
         self.lr_scheduler = lr_scheduler
 
         # Be backwards compatible by turning list into dict where each item is its own key-value
@@ -112,6 +114,16 @@ def configure_optimizers(self):
 
         return configure_optimizers(self.model, self.optimizer_fn, self.lr_scheduler)
 
+    def configure_gradient_clipping(
+        self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None
+    ):
+        # Configuring gradient clipping in pl.Trainer is still useful, so use it.
+        super().configure_gradient_clipping(optimizer, gradient_clip_val, gradient_clip_algorithm)
+
+        if self.gradient_modifier:
+            for group in optimizer.param_groups:
+                self.gradient_modifier(group["params"])
+
     def forward(self, **kwargs):
         return self.model(**kwargs)
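
Note (not part of the patch): a minimal sketch of a gradient modifier that could be passed to LitModular through the new `gradient_modifier` argument. The hook is called from `configure_gradient_clipping`, after the trainer's own clipping, with each optimizer param group's `"params"` list. The class name `SignGradientModifier` and the wiring below are hypothetical; MART ships its own gradient modifiers for Adversary, and only the call signature shown in the diff is assumed here.

import torch


class SignGradientModifier:
    """Replace each parameter's gradient with its element-wise sign (hypothetical example)."""

    @torch.no_grad()
    def __call__(self, parameters):
        for param in parameters:
            if param.grad is not None:
                param.grad.sign_()  # in-place sign, as in FGSM/PGD-style perturbation updates


# Hypothetical wiring; LitModular's other constructor arguments are elided.
# lit_model = LitModular(
#     modules=modules,
#     optimizer=optimizer_factory,
#     gradient_modifier=SignGradientModifier(),
# )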