From cf7c1c799f83501e4a1e29027d74044ae3d8e1c7 Mon Sep 17 00:00:00 2001
From: nobody <63745715+2zki3@users.noreply.github.com>
Date: Sun, 13 Jun 2021 22:38:20 -0400
Subject: [PATCH 1/6] Add LayerNorm for FisherBlockEff extension.

---
 .../firstorder/fisher_block_eff/layernorm.py | 54 +++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 backpack/extensions/firstorder/fisher_block_eff/layernorm.py

diff --git a/backpack/extensions/firstorder/fisher_block_eff/layernorm.py b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
new file mode 100644
index 000000000..e34e9f8e1
--- /dev/null
+++ b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
@@ -0,0 +1,54 @@
+import torch
+
+from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
+from backpack.extensions.firstorder.fisher_block_eff.fisher_block_eff_base import FisherBlockEffBase
+
+class FisherBlockEffLayerNorm(FisherBlockEffBase):
+    def __init__(self, damping=1.0):
+        self.damping = damping
+        super().__init__(derivatives=BaseParameterDerivatives(), params=["bias", "weight"])
+
+    def weight(self, ext, module, g_inp, g_out, backproped):
+        I = module.input0
+        assert(len(I.shape) in [2, 4]) # linear or conv
+        n, c, h, w = I.shape[0], I.shape[1], 1, 1 # input shape = output shape for LayerNorm
+        g_out_sc = n * g_out[0]
+
+        if len(I.shape) == 4: # conv
+            h, w = I.shape[2], I.shape[3]
+            # flatten: [n, c, h * w]
+            I = I.reshape(n, c, -1)
+            g_out_sc = g_out_sc.reshape(n, c, -1)
+
+        G = g_out_sc
+
+        grad = module.weight.grad.reshape(-1)
+
+        mean = I.mean(dim=-1).unsqueeze(-1)
+        var = I.var(dim=-1, unbiased=False).unsqueeze(-1)
+
+        x_hat = (I - mean) / (var + module.eps).sqrt()
+
+        J = g_out_sc * x_hat
+        J = J.reshape(J.shape[0], -1)
+        JJT = torch.matmul(J, J.t())
+
+        grad_prod = torch.matmul(J, grad)
+
+        NGD_kernel = JJT / n
+        NGD_inv = torch.linalg.inv(NGD_kernel + self.damping * torch.eye(n).to(grad.device))
+        v = torch.matmul(NGD_inv, grad_prod)
+
+        gv = torch.matmul(J.t(), v) / n
+
+        update = (grad - gv) / self.damping
+        update = update.reshape(module.weight.grad.shape)
+
+        module.I = I
+        module.G = G
+        module.NGD_inv = NGD_inv
+
+        return update
+
+    def bias(self, ext, module, g_inp, g_out, backproped):
+        return module.bias.grad
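
The weight() update in this patch never forms the full parameter-space Fisher block; it inverts only the n x n kernel JJT / n + damping * I and maps the result back through J. The snippet below is a minimal, standalone sketch (random tensors and illustrative sizes, not part of the patch) that checks this kernel-space form against a direct damped solve via the matrix-inversion identity.

# Standalone sanity check of the kernel-space update used in
# FisherBlockEffLayerNorm.weight; J, grad and the sizes here are made up.
import torch

n, p, damping = 8, 5, 1.0
J = torch.randn(n, p)      # per-sample Jacobian rows, [n, p]
grad = torch.randn(p)      # averaged parameter gradient, [p]

# Direct solve in parameter space: (J^T J / n + damping * I_p)^{-1} grad
F_damped = J.t() @ J / n + damping * torch.eye(p)
direct = torch.linalg.solve(F_damped, grad)

# Kernel-space form from the patch: only an n x n inverse is needed
NGD_inv = torch.linalg.inv(J @ J.t() / n + damping * torch.eye(n))
gv = J.t() @ (NGD_inv @ (J @ grad)) / n
update = (grad - gv) / damping

print(torch.allclose(direct, update, atol=1e-5))  # expected: True
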
From 4b4981a90276069fb808b520857c75ff4bdfcbbd Mon Sep 17 00:00:00 2001
From: nobody <63745715+2zki3@users.noreply.github.com>
Date: Sun, 13 Jun 2021 22:39:11 -0400
Subject: [PATCH 2/6] Add LayerNorm in backpack.

---
 backpack/.DS_Store                           | Bin 8196 -> 8196 bytes
 backpack/__init__.py                         |   4 ++--
 backpack/core/derivatives/__init__.py        |   3 +++
 backpack/extensions/.DS_Store                | Bin 8196 -> 8196 bytes
 backpack/extensions/firstorder/.DS_Store     | Bin 10244 -> 10244 bytes
 .../firstorder/fisher_block_eff/.DS_Store    | Bin 6148 -> 6148 bytes
 .../firstorder/fisher_block_eff/__init__.py  |   7 +++++--
 7 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/backpack/.DS_Store b/backpack/.DS_Store
index eb0087a9fdec1809dc1374b4c0750afa3ab11df9..0054067bf28f4a98fe737cede6d400a430cb2ef7 100644
Binary files a/backpack/.DS_Store and b/backpack/.DS_Store differ

diff --git a/backpack/extensions/firstorder/.DS_Store b/backpack/extensions/firstorder/.DS_Store
index b3fa9dd27bca00838331be8155d77aaa77c3fa2a..5e83910de8ef85fe75a1a08ec3220448950d293e 100644
Binary files a/backpack/extensions/firstorder/.DS_Store and b/backpack/extensions/firstorder/.DS_Store differ

Date: Mon, 14 Jun 2021 04:28:55 -0400
Subject: [PATCH 3/6] Try take mean and var over surface (not volume) for 4D tensor.

---
 .DS_Store                                    | Bin 10244 -> 10244 bytes
 .../firstorder/fisher_block_eff/layernorm.py |   5 ++++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/.DS_Store b/.DS_Store
index bd5233c34d70de86e47e818c6e4e695823306c36..2bab8a66a9b431a79d0e56cdb2ea349bb42c0a32 100644
Binary files a/.DS_Store and b/.DS_Store differ

diff --git a/backpack/extensions/firstorder/fisher_block_eff/layernorm.py b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
index e34e9f8e1..283f0300d 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
@@ -29,7 +29,10 @@ def weight(self, ext, module, g_inp, g_out, backproped):
 
         x_hat = (I - mean) / (var + module.eps).sqrt()
 
-        J = g_out_sc * x_hat
+        if len(I.shape) == 2:
+            J = g_out_sc * x_hat
+        else:
+            J = torch.einsum('ncf,ncf->nf', g_out_sc, x_hat)
         J = J.reshape(J.shape[0], -1)
         JJT = torch.matmul(J, J.t())
 
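
For a 4D activation reshaped to [n, c, f] with f = h * w, "surface" statistics reduce over [h, w] only (one mean/var per channel), while "volume" statistics reduce over [c, h, w] (one per sample); these correspond to nn.LayerNorm with normalized_shape=(h, w) and normalized_shape=(c, h, w) respectively. A small sketch of the two reductions, with made-up shapes and independent of the extension code:

# Illustrative comparison of the two reduction choices; n, c, h, w are made up.
import torch

n, c, h, w = 2, 3, 4, 5
I = torch.randn(n, c, h, w).reshape(n, c, -1)   # [n, c, f], f = h * w

# Surface: per-channel statistics over [h, w] only
mean_hw = I.mean(dim=-1, keepdim=True)                            # [n, c, 1]
var_hw = ((I - mean_hw) ** 2).mean(dim=-1, keepdim=True)          # [n, c, 1]

# Volume: per-sample statistics over [c, h, w]
mean_chw = I.mean(dim=(-2, -1), keepdim=True)                     # [n, 1, 1]
var_chw = ((I - mean_chw) ** 2).mean(dim=(-2, -1), keepdim=True)  # [n, 1, 1]

print(mean_hw.shape, mean_chw.shape)
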
From e83750946b397c8010b0f4cade5d351ccbc2da14 Mon Sep 17 00:00:00 2001
From: nobody <63745715+2zki3@users.noreply.github.com>
Date: Wed, 16 Jun 2021 13:17:46 -0400
Subject: [PATCH 4/6] Add bias.

---
 .../firstorder/fisher_block_eff/layernorm.py | 108 ++++++++++++------
 1 file changed, 71 insertions(+), 37 deletions(-)

diff --git a/backpack/extensions/firstorder/fisher_block_eff/layernorm.py b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
index 283f0300d..6ef30a996 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
@@ -4,54 +4,88 @@
 from backpack.extensions.firstorder.fisher_block_eff.fisher_block_eff_base import FisherBlockEffBase
 
 class FisherBlockEffLayerNorm(FisherBlockEffBase):
-    def __init__(self, damping=1.0):
-        self.damping = damping
-        super().__init__(derivatives=BaseParameterDerivatives(), params=["bias", "weight"])
+    def __init__(self, damping=1.0):
+        self.damping = damping
+        super().__init__(derivatives=BaseParameterDerivatives(), params=["bias", "weight"])
 
-    def weight(self, ext, module, g_inp, g_out, backproped):
-        I = module.input0
-        assert(len(I.shape) in [2, 4]) # linear or conv
-        n, c, h, w = I.shape[0], I.shape[1], 1, 1 # input shape = output shape for LayerNorm
-        g_out_sc = n * g_out[0]
+    def weight(self, ext, module, g_inp, g_out, backproped):
+        I = module.input0
+        assert(len(I.shape) in [2, 4]) # linear or conv
+        n, c, h, w = I.shape[0], I.shape[1], 1, 1 # input shape = output shape for LayerNorm
+        g_out_sc = n * g_out[0]
 
-        if len(I.shape) == 4: # conv
-            h, w = I.shape[2], I.shape[3]
-            # flatten: [n, c, h * w]
-            I = I.reshape(n, c, -1)
-            g_out_sc = g_out_sc.reshape(n, c, -1)
+        if len(I.shape) == 4: # conv
+            h, w = I.shape[2], I.shape[3]
+            # flatten: [n, c, h * w]
+            I = I.reshape(n, c, -1)
+            g_out_sc = g_out_sc.reshape(n, c, -1)
 
-        G = g_out_sc
+        G = g_out_sc
 
-        grad = module.weight.grad.reshape(-1)
+        grad = module.weight.grad.reshape(-1)
 
-        mean = I.mean(dim=-1).unsqueeze(-1)
-        var = I.var(dim=-1, unbiased=False).unsqueeze(-1)
+        if len(I.shape) == 2:
+            mean = I.mean(dim=-1).unsqueeze(-1)
+            var = I.var(dim=-1, unbiased=False).unsqueeze(-1)
+        else:
+            mean = I.mean((-2, -1), keepdims=True)
+            var = I.var((-2, -1), unbiased=False, keepdims=True)
 
-        x_hat = (I - mean) / (var + module.eps).sqrt()
+        x_hat = (I - mean) / (var + module.eps).sqrt()
 
-        if len(I.shape) == 2:
-            J = g_out_sc * x_hat
-        else:
-            J = torch.einsum('ncf,ncf->nf', g_out_sc, x_hat)
-        J = J.reshape(J.shape[0], -1)
-        JJT = torch.matmul(J, J.t())
+        J = g_out_sc * x_hat
 
-        grad_prod = torch.matmul(J, grad)
+        J = J.reshape(J.shape[0], -1)
+        JJT = torch.matmul(J, J.t())
 
-        NGD_kernel = JJT / n
-        NGD_inv = torch.linalg.inv(NGD_kernel + self.damping * torch.eye(n).to(grad.device))
-        v = torch.matmul(NGD_inv, grad_prod)
+        grad_prod = torch.matmul(J, grad)
 
-        gv = torch.matmul(J.t(), v) / n
+        NGD_kernel = JJT / n
+        NGD_inv = torch.linalg.inv(NGD_kernel + self.damping * torch.eye(n).to(grad.device))
+        v = torch.matmul(NGD_inv, grad_prod)
 
-        update = (grad - gv) / self.damping
-        update = update.reshape(module.weight.grad.shape)
+        gv = torch.matmul(J.t(), v) / n
 
-        module.I = I
-        module.G = G
-        module.NGD_inv = NGD_inv
+        update = (grad - gv) / self.damping
+        update = update.reshape(module.weight.grad.shape)
 
-        return update
+        module.I = I
+        module.G = G
+        module.NGD_inv = NGD_inv
 
-    def bias(self, ext, module, g_inp, g_out, backproped):
-        return module.bias.grad
+        return update
+
+    def bias(self, ext, module, g_inp, g_out, backproped):
+        I = module.input0
+        assert(len(I.shape) in [2, 4])
+        n, c, h, w = I.shape[0], I.shape[1], 1, 1
+        g_out_sc = n * g_out[0]
+
+        if len(I.shape) == 4:
+            h, w = I.shape[2], I.shape[3]
+            I = I.reshape(n, c, -1)
+            g_out_sc = g_out_sc.reshape(n, c, -1)
+
+        G = g_out_sc
+
+        grad = module.bias.grad.reshape(-1)
+
+        J = g_out_sc
+        J = J.reshape(J.shape[0], -1)
+        JJT = torch.matmul(J, J.t())
+
+        grad_prod = torch.matmul(J, grad)
+
+        NGD_kernel = JJT / n
+        NGD_inv = torch.linalg.inv(NGD_kernel + self.damping * torch.eye(n).to(grad.device))
+        v = torch.matmul(NGD_inv, grad_prod)
+
+        gv = torch.matmul(J.t(), v) / n
+
+        update = (grad - gv) / self.damping
+        update = update.reshape(module.bias.grad.shape)
+
+        module.NGD_inv = NGD_inv
+        module.G = G
+
+        return update
\ No newline at end of file
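
The else branch introduced above takes the statistics over [c, h, w] of the flattened [n, c, f] activation, which is what an affine-free nn.LayerNorm((c, h, w)) computes internally. A quick check of that equivalence, again with made-up shapes and independent of the extension machinery:

# Check that the manual x_hat in FisherBlockEffLayerNorm.weight matches the
# output of an elementwise_affine=False LayerNorm over (c, h, w); shapes made up.
import torch

n, c, h, w = 2, 3, 4, 5
ln = torch.nn.LayerNorm((c, h, w), elementwise_affine=False)
x = torch.randn(n, c, h, w)

I = x.reshape(n, c, -1)
mean = I.mean(dim=(-2, -1), keepdim=True)
var = ((I - mean) ** 2).mean(dim=(-2, -1), keepdim=True)
x_hat = (I - mean) / (var + ln.eps).sqrt()

print(torch.allclose(ln(x).reshape(n, c, -1), x_hat, atol=1e-5))  # expected: True
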
From 5be4d6b6993667a4e7426a7f6ad2b7b8b99fa6cc Mon Sep 17 00:00:00 2001
From: nobody <63745715+2zki3@users.noreply.github.com>
Date: Thu, 17 Jun 2021 08:28:04 -0400
Subject: [PATCH 5/6] Comment [h, w] case.

---
 .../firstorder/fisher_block_eff/layernorm.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/backpack/extensions/firstorder/fisher_block_eff/layernorm.py b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
index 6ef30a996..304e73b24 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
@@ -24,6 +24,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
 
         grad = module.weight.grad.reshape(-1)
 
+        # compute mean, var over [c, h, w]
         if len(I.shape) == 2:
             mean = I.mean(dim=-1).unsqueeze(-1)
             var = I.var(dim=-1, unbiased=False).unsqueeze(-1)
@@ -31,10 +32,19 @@ def weight(self, ext, module, g_inp, g_out, backproped):
             mean = I.mean((-2, -1), keepdims=True)
             var = I.var((-2, -1), unbiased=False, keepdims=True)
 
+        # compute mean, var over [h, w]
+        # mean = I.mean(dim=-1).unsqueeze(-1)
+        # var = I.var(dim=-1, unbiased=False).unsqueeze(-1)
+
         x_hat = (I - mean) / (var + module.eps).sqrt()
 
         J = g_out_sc * x_hat
 
+        # if len(I.shape) == 2:
+        #     J = g_out_sc * x_hat
+        # else:
+        #     J = torch.einsum('ncf,ncf->nf', g_out_sc, x_hat)
+
         J = J.reshape(J.shape[0], -1)
         JJT = torch.matmul(J, J.t())
 
@@ -71,6 +81,12 @@ def bias(self, ext, module, g_inp, g_out, backproped):
         grad = module.bias.grad.reshape(-1)
 
         J = g_out_sc
+
+        # if len(I.shape) == 2:
+        #     J = g_out_sc
+        # else:
+        #     J = torch.einsum('ncf->nf', g_out_sc)
+
         J = J.reshape(J.shape[0], -1)
         JJT = torch.matmul(J, J.t())
 
From 076afbbe06d676f7fa93907757feb8262eed2d37 Mon Sep 17 00:00:00 2001
From: nobody <63745715+2zki3@users.noreply.github.com>
Date: Fri, 18 Jun 2021 23:30:53 -0400
Subject: [PATCH 6/6] Add save NGD_kernel option.

---
 .../extensions/firstorder/fisher_block_eff/__init__.py  | 9 +++++----
 .../extensions/firstorder/fisher_block_eff/conv2d.py    | 6 ++++--
 .../extensions/firstorder/fisher_block_eff/layernorm.py | 7 ++++++-
 .../extensions/firstorder/fisher_block_eff/linear.py    | 8 +++++++-
 4 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/backpack/extensions/firstorder/fisher_block_eff/__init__.py b/backpack/extensions/firstorder/fisher_block_eff/__init__.py
index 572270d10..37b5ef48c 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/__init__.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/__init__.py
@@ -22,22 +22,23 @@
 
 
 class FisherBlockEff(BackpropExtension):
-    def __init__(self, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false'):
+    def __init__(self, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false', save_kernel='false'):
         self.gamma = gamma
         self.damping = damping
         self.alpha =alpha
         self.low_rank = low_rank
         self.memory_efficient = memory_efficient
         self.super_opt = super_opt
+        self.save_kernel = save_kernel
         super().__init__(
             savefield="fisher_block",
             fail_mode="WARNING",
             module_exts={
-                Linear: linear.FisherBlockEffLinear(self.damping, self.alpha),
+                Linear: linear.FisherBlockEffLinear(self.damping, self.alpha, self.save_kernel),
                 Conv1d: conv1d.FisherBlockEffConv1d(self.damping),
-                Conv2d: conv2d.FisherBlockEffConv2d(self.damping, self.low_rank, self.gamma, self.memory_efficient, self.super_opt),
+                Conv2d: conv2d.FisherBlockEffConv2d(self.damping, self.low_rank, self.gamma, self.memory_efficient, self.super_opt, self.save_kernel),
                 BatchNorm1d: batchnorm1d.FisherBlockEffBatchNorm1d(self.damping),
                 BatchNorm2d: batchnorm2d.FisherBlockEffBatchNorm2d(self.damping),
-                LayerNorm: layernorm.FisherBlockEffLayerNorm(self.damping)
+                LayerNorm: layernorm.FisherBlockEffLayerNorm(self.damping, self.save_kernel)
             },
         )
diff --git a/backpack/extensions/firstorder/fisher_block_eff/conv2d.py b/backpack/extensions/firstorder/fisher_block_eff/conv2d.py
index 07b29db8c..9bf3b742e 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/conv2d.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/conv2d.py
@@ -14,12 +14,13 @@
 # import matplotlib.pylab as plt
 MODE = 0
 class FisherBlockEffConv2d(FisherBlockEffBase):
-    def __init__(self, damping=1.0, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false'):
+    def __init__(self, damping=1.0, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false', save_kernel='false'):
         self.damping = damping
         self.low_rank = low_rank
         self.gamma = gamma
         self.memory_efficient = memory_efficient
         self.super_opt = super_opt
+        self.save_kernel = save_kernel
         super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"])
 
     def weight(self, ext, module, g_inp, g_out, bpQuantities):
@@ -106,7 +107,8 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
         gv = einsum("nkm,n->mk", (AX, v)).view_as(grad) /n
 
         module.AX = AX
-
+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel
         update = (grad - gv)/self.damping
 
         return update
diff --git a/backpack/extensions/firstorder/fisher_block_eff/layernorm.py b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
index 304e73b24..a7aa0d654 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
@@ -4,8 +4,9 @@
 from backpack.extensions.firstorder.fisher_block_eff.fisher_block_eff_base import FisherBlockEffBase
 
 class FisherBlockEffLayerNorm(FisherBlockEffBase):
-    def __init__(self, damping=1.0):
+    def __init__(self, damping=1.0, save_kernel='false'):
         self.damping = damping
+        self.save_kernel = save_kernel
         super().__init__(derivatives=BaseParameterDerivatives(), params=["bias", "weight"])
 
     def weight(self, ext, module, g_inp, g_out, backproped):
@@ -62,6 +63,8 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         module.I = I
         module.G = G
         module.NGD_inv = NGD_inv
+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel
 
         return update
 
@@ -103,5 +106,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
 
         module.NGD_inv = NGD_inv
         module.G = G
+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel
 
         return update
\ No newline at end of file
diff --git a/backpack/extensions/firstorder/fisher_block_eff/linear.py b/backpack/extensions/firstorder/fisher_block_eff/linear.py
index c5164a451..92ab4ef42 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/linear.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/linear.py
@@ -6,9 +6,10 @@
 
 
 class FisherBlockEffLinear(FisherBlockEffBase):
-    def __init__(self, damping=1.0, alpha=0.95):
+    def __init__(self, damping=1.0, alpha=0.95, save_kernel='false'):
         self.damping = damping
         self.alpha = alpha
+        self.save_kernel = save_kernel
         super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"])
 
     def weight(self, ext, module, g_inp, g_out, backproped):
@@ -42,6 +43,8 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         module.I = I
         module.G = G
         module.NGD_inv = NGD_inv
+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel
 
         return update
 
@@ -68,6 +71,9 @@ def bias(self, ext, module, g_inp, g_out, backproped):
 
         update = (grad - gv)/self.damping
         # update = grad
+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel
+
         return update
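
After this series, a backward pass run under the extension leaves NGD_inv (and, with save_kernel='true', NGD_kernel) on each supported module. A rough usage sketch, assuming the fork exports FisherBlockEff from backpack.extensions and with the model, data and loss below as placeholders only:

# Hypothetical driver code; FisherBlockEff's import location and the model,
# data and loss here are assumptions for illustration, not part of the patches.
import torch
from backpack import backpack, extend
from backpack.extensions import FisherBlockEff  # assumed export in this fork

model = extend(torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.LayerNorm(10)))
lossfunc = extend(torch.nn.MSELoss())

x, y = torch.randn(4, 10), torch.randn(4, 10)
loss = lossfunc(model(x), y)

with backpack(FisherBlockEff(damping=1.0, save_kernel='true')):
    loss.backward()

for module in model.children():
    # NGD_inv is always stored by the extension; NGD_kernel only when
    # save_kernel == 'true', as added in PATCH 6/6.
    print(type(module).__name__, getattr(module, "NGD_kernel", None) is not None)
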