From b4c84b1e1afc373062bdef16e25c67d403f12b58 Mon Sep 17 00:00:00 2001
From: jacenfox
Date: Fri, 17 May 2019 10:35:44 -0400
Subject: [PATCH 1/2] Gradient Record and Scale

---
 pytorch_toolbox/probe/gradient_probe.py | 71 +++++++++++++++++++++++++
 1 file changed, 71 insertions(+)
 create mode 100644 pytorch_toolbox/probe/gradient_probe.py

diff --git a/pytorch_toolbox/probe/gradient_probe.py b/pytorch_toolbox/probe/gradient_probe.py
new file mode 100644
index 0000000..368a42d
--- /dev/null
+++ b/pytorch_toolbox/probe/gradient_probe.py
@@ -0,0 +1,71 @@
+'''
+Simple way to monitor/manipulate gradients.
+
+Usage:
+    gradient_record = GradientRecordHook(name='record')
+    gradient_scale = GradientScale(name='scale')
+    def forward(self, input):  # your network's forward function
+        fc = base_layers(input)
+        # check the gradient here:
+        fc = gradient_record(fc)
+        # invert the gradient (e.g. for domain adaptation):
+        fc = gradient_scale(fc, -1)
+        output = estimator_layers(fc)
+    Afterwards, we can plot these records to check the gradient.
+
+'''
+import torch
+
+
+class GradientScaleFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, lambda_):
+        ctx.lambda_ = lambda_
+        return x.clone()
+
+    @staticmethod
+    def backward(ctx, grads):
+        # for debug
+        # print('[min, mean, max]=[%.15f %.15f %.15f]' % (grads.min(), grads.mean(), grads.max()))
+        lambda_ = ctx.lambda_
+        lambda_ = grads.new_tensor(lambda_)
+        dx = lambda_ * grads
+        return dx, None
+
+
+class GradientScale(torch.nn.Module):
+    def __init__(self, name=None):
+        super(GradientScale, self).__init__()
+        self.name = name
+        self.lambdar = 0
+
+    def forward(self, x, lambdar):
+        self.lambdar = lambdar
+        return GradientScaleFunction.apply(x, lambdar)
+
+
+class GradientRecordHook(torch.nn.Module):
+    '''
+    Simple way to record gradient statistics
+    '''
+
+    def __init__(self, name=None):
+        super(GradientRecordHook, self).__init__()
+        self.lambdar = 0
+        self.gradients = []
+        self.mag = None
+        self.std = None
+        self.name = name
+
+    def hook_fun(self, grad):
+        self.mag = torch.mean(torch.abs(grad)).item()
+        self.std = torch.std(grad).item()
+
+    def forward(self, x):
+        '''
+        Register the gradient hook during training; acts as a pass-through otherwise.
+        '''
+        if self.training:
+            x.register_hook(self.hook_fun)
+        return x
+

From 82ad0da425466ef9224b7a149367f944dde673cf Mon Sep 17 00:00:00 2001
From: jacenfox
Date: Fri, 17 May 2019 10:37:06 -0400
Subject: [PATCH 2/2] gradient record and scale

---
 pytorch_toolbox/probe/gradient_probe.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/pytorch_toolbox/probe/gradient_probe.py b/pytorch_toolbox/probe/gradient_probe.py
index 368a42d..4c0ce00 100644
--- a/pytorch_toolbox/probe/gradient_probe.py
+++ b/pytorch_toolbox/probe/gradient_probe.py
@@ -12,7 +12,6 @@ def forward(self, input):  # your network's forward function
         fc = gradient_scale(fc, -1)
         output = estimator_layers(fc)
     Afterwards, we can plot these records to check the gradient.
-
 '''
 import torch
 
@@ -25,8 +24,6 @@ def forward(ctx, x, lambda_):
 
     @staticmethod
     def backward(ctx, grads):
-        # for debug
-        # print('[min, mean, max]=[%.15f %.15f %.15f]' % (grads.min(), grads.mean(), grads.max()))
        lambda_ = ctx.lambda_
        lambda_ = grads.new_tensor(lambda_)
        dx = lambda_ * grads
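
Usage note (not part of the patches above): a minimal end-to-end sketch of how the two probes could be dropped into a network. Only GradientRecordHook, GradientScale, and the module path pytorch_toolbox.probe.gradient_probe come from the patch; ToyNet, its layer sizes, the probe names, and the random input are made up for illustration.

    import torch
    from pytorch_toolbox.probe.gradient_probe import GradientRecordHook, GradientScale

    class ToyNet(torch.nn.Module):
        def __init__(self):
            super(ToyNet, self).__init__()
            self.base = torch.nn.Linear(8, 4)
            self.head = torch.nn.Linear(4, 2)
            self.record = GradientRecordHook(name='fc_record')  # hypothetical probe names
            self.scale = GradientScale(name='fc_scale')

        def forward(self, x):
            fc = torch.relu(self.base(x))
            fc = self.record(fc)        # hook fires during backward and stores .mag / .std
            fc = self.scale(fc, -1.0)   # gradients flowing back into self.base are inverted
            return self.head(fc)

    net = ToyNet().train()              # the hook is only registered in training mode
    net(torch.randn(16, 8)).sum().backward()
    print(net.record.name, net.record.mag, net.record.std)

Note that because the record probe sits upstream of the scale probe in the forward pass, the recorded statistics reflect the gradient after the scale's backward (here, after inversion) has been applied.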