From d8a631dafa762a4175dd9b5b35e07726a8f638e4 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Tue, 24 Aug 2021 15:29:42 -0400 Subject: [PATCH 01/19] Add FusedFisherBlock extension. (cherry picked from commit 8a76e44bfe667f7c1a010b8768523fc7dab8de8f) --- backpack/extensions/__init__.py | 2 ++ backpack/extensions/secondorder/__init__.py | 2 ++ .../fused_fisher_block/__init__.py | 21 +++++++++++++++++++ .../fused_fisher_block_base.py | 6 ++++++ .../secondorder/fused_fisher_block/linear.py | 0 5 files changed, 31 insertions(+) create mode 100644 backpack/extensions/secondorder/fused_fisher_block/__init__.py create mode 100644 backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py create mode 100644 backpack/extensions/secondorder/fused_fisher_block/linear.py diff --git a/backpack/extensions/__init__.py b/backpack/extensions/__init__.py index d6549ea58..91b760de2 100644 --- a/backpack/extensions/__init__.py +++ b/backpack/extensions/__init__.py @@ -15,6 +15,7 @@ DiagGGNMC, DiagHessian, MNGD, + FusedFisherBlock ) __all__ = [ @@ -38,4 +39,5 @@ "DiagGGNMC", "DiagGGN", "DiagHessian", + "FusedFisherBlock" ] diff --git a/backpack/extensions/secondorder/__init__.py b/backpack/extensions/secondorder/__init__.py index 1f7f24ec9..0adabe19d 100644 --- a/backpack/extensions/secondorder/__init__.py +++ b/backpack/extensions/secondorder/__init__.py @@ -23,6 +23,7 @@ from .diag_hessian import DiagHessian from .hbp import HBP, KFAC, KFLR, KFRA from .mngd import MNGD +from .fused_fisher_block import FusedFisherBlock __all__ = [ "MNGD", @@ -35,4 +36,5 @@ "KFLR", "KFRA", "HBP", + "FusedFisherBlock" ] diff --git a/backpack/extensions/secondorder/fused_fisher_block/__init__.py b/backpack/extensions/secondorder/fused_fisher_block/__init__.py new file mode 100644 index 000000000..d077f5cfc --- /dev/null +++ b/backpack/extensions/secondorder/fused_fisher_block/__init__.py @@ -0,0 +1,21 @@ +from torch.nn import ( + Linear +) + +from backpack.extensions.backprop_extension import BackpropExtension + +from . import ( + linear +) + +class FusedFisherBlock(BackpropExtension): + def __init__(self, loss_sample, damping=1.0): + self.loss_sample = loss_sample + self.damping = damping + super().__init__( + savefield="fused_fisher_block", + fail_mode="WARNING", + module_exts={ + Linear: linear.FusedFisherBlockLinear(self.damping) + }, + ) diff --git a/backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py b/backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py new file mode 100644 index 000000000..a2f1a5746 --- /dev/null +++ b/backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py @@ -0,0 +1,6 @@ +from backpack.extensions.mat_to_mat_jac_base import MatToJacMat + + +class FusedFisherBlockBaseModule(MatToJacMat): + def __init__(self, derivatives, params=None): + super().__init__(derivatives, params=params) diff --git a/backpack/extensions/secondorder/fused_fisher_block/linear.py b/backpack/extensions/secondorder/fused_fisher_block/linear.py new file mode 100644 index 000000000..e69de29bb From 36c8eeb5a097b489476542610a7b7d2baea2555d Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Wed, 25 Aug 2021 14:52:52 -0400 Subject: [PATCH 02/19] init linear. 
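In outline, the weight and bias rules below return a damped natural-gradient step for this block without ever forming the parameter-space Fisher. Writing J for the n x p matrix whose rows are the individual per-example gradients of this layer's parameters, g for the averaged gradient and lambda for the damping (notation used here for explanation only), the step follows the Sherman-Morrison-Woodbury identity

    (lambda * I_p + (1/n) J^T J)^{-1} g
        = (1/lambda) * ( g - (1/n) J^T (lambda * I_n + (1/n) J J^T)^{-1} J g ),

so only an n x n system has to be solved instead of a p x p one. For a linear layer the Gram matrix factors elementwise, (J J^T)_{ml} = (I I^T)_{ml} * (G G^T)_{ml} with I the layer inputs and G the per-example output gradients, which is what NGD_kernel stores.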
--- .../secondorder/fused_fisher_block/linear.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/backpack/extensions/secondorder/fused_fisher_block/linear.py b/backpack/extensions/secondorder/fused_fisher_block/linear.py index e69de29bb..b66997f4c 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/linear.py +++ b/backpack/extensions/secondorder/fused_fisher_block/linear.py @@ -0,0 +1,70 @@ +from torch import einsum, eye, matmul, ones_like, norm +from torch.linalg import inv + +from backpack.core.derivatives.linear import LinearDerivatives +from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule + + +class FusedFisherBlockLinear(FusedFisherBlockBaseModule): + def __init__(self, damping=1.0): + self.damping = damping + self.alpha = alpha + super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) + + def weight(self, ext, module, g_inp, g_out, backproped): + # TODO(bmu): manully backprop quantities in the extra b/w pass + + I = module.input0 + n = g_out[0].shape[0] + g_out_sc = n * g_out[0] + G = g_out_sc + grad = module.weight.grad + + B = einsum("ni,li->nl", (I, I)) + A = einsum("no,lo->nl", (G, G)) + + # compute vector jacobian product in optimization method + grad_prod = einsum("ni,oi->no", (I, grad)) + grad_prod = einsum("no,no->n", (grad_prod, G)) + # grad_prod = 0 + out = A * B + # out = 0 + NGD_kernel = out / n + NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device)) + v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() + + gv = einsum("n,no->no", (v, G)) + gv = einsum("no,ni->oi", (gv, I)) + gv = gv / n + + update = (grad - gv)/self.damping + + module.I = I + module.G = G + module.NGD_inv = NGD_inv + return update + + + def bias(self, ext, module, g_inp, g_out, backproped): + # TODO(bmu): manully backprop quantities in the extra b/w pass + + grad = module.bias.grad + n = g_out[0].shape[0] + g_out_sc = n * g_out[0] + + # compute vector jacobian product in optimization method + grad_prod = einsum("no,o->n", (g_out_sc, grad)) + # grad_prod = 0 + out = einsum("no,lo->nl", g_out_sc, g_out_sc) + # out = 0 + + NGD_kernel = out / n + NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device)) + v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() + gv = einsum("n,no->o", (v, g_out_sc)) + gv = gv / n + + update = (grad - gv)/self.damping + + return update + From 0d0ff34fd5629d3f6dbe58e0609854bf61bccba8 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Wed, 25 Aug 2021 14:53:14 -0400 Subject: [PATCH 03/19] init cross entropy loss. --- .../secondorder/fused_fisher_block/losses.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 backpack/extensions/secondorder/fused_fisher_block/losses.py diff --git a/backpack/extensions/secondorder/fused_fisher_block/losses.py b/backpack/extensions/secondorder/fused_fisher_block/losses.py new file mode 100644 index 000000000..48c756434 --- /dev/null +++ b/backpack/extensions/secondorder/fused_fisher_block/losses.py @@ -0,0 +1,22 @@ +from functools import partial + +from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives +from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule + + +class FusedFisherBlockLoss(FusedFisherBlockBaseModule): + def backpropagate(self, ext, module, grad_inp, grad_out, backproped): + # backprop symmetric factorization of the hessian of loss w.r.t. 
outputs of the network, + # i.e. H = SS^T + hess_func = self.make_loss_hessian_func(ext) + + return hess_func(module, grad_inp, grad_out) + + def make_loss_hessian_func(self, ext): + # TODO(bmu): try both exact and MC sampling + return self.derivatives.sqrt_hessian + + +class FusedFisherBlockCrossEntropyLoss(FusedFisherBlockLoss): + def __init__(self): + super().__init__(derivatives=CrossEntropyLossDerivatives()) From 6adc87016aa94949c63f8c780c325a4af314bf67 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Wed, 25 Aug 2021 16:59:41 -0400 Subject: [PATCH 04/19] enable cross entropy loss. --- backpack/__init__.py | 4 ++-- .../extensions/secondorder/fused_fisher_block/__init__.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/backpack/__init__.py b/backpack/__init__.py index 4f0deb64d..59efd5771 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -83,7 +83,7 @@ def hook_store_io(module, input, output): input: List of input tensors output: output tensor """ - if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)): + if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.CrossEntropyLoss)): for i in range(len(input)): setattr(module, "input{}".format(i), input[i]) module.output = output @@ -134,7 +134,7 @@ def hook_run_extensions(module, g_inp, g_out): for backpack_extension in CTX.get_active_exts(): if CTX.get_debug(): print("[DEBUG] Running extension", backpack_extension, "on", module) - if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): + if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.CrossEntropyLoss): backpack_extension.apply(module, g_inp, g_out) if not ( diff --git a/backpack/extensions/secondorder/fused_fisher_block/__init__.py b/backpack/extensions/secondorder/fused_fisher_block/__init__.py index d077f5cfc..f97149977 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/__init__.py +++ b/backpack/extensions/secondorder/fused_fisher_block/__init__.py @@ -1,11 +1,13 @@ from torch.nn import ( + CrossEntropyLoss, Linear ) from backpack.extensions.backprop_extension import BackpropExtension from . import ( - linear + linear, + losses ) class FusedFisherBlock(BackpropExtension): @@ -16,6 +18,7 @@ def __init__(self, loss_sample, damping=1.0): savefield="fused_fisher_block", fail_mode="WARNING", module_exts={ + CrossEntropyLoss: losses.FusedFisherBlockCrossEntropyLoss(), Linear: linear.FusedFisherBlockLinear(self.damping) }, ) From 9ec8aceeb24368c330d08d96ff68e81c7738dbe8 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Wed, 25 Aug 2021 17:50:29 -0400 Subject: [PATCH 05/19] remove extra arg. 
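The loss_sample argument was never used, so the constructor now only takes the damping. For reference, a minimal usage sketch of the extension as it stands (layer sizes and batch are illustrative placeholders, not part of this patch):

    import torch
    from torch import nn

    from backpack import backpack, extend
    from backpack.extensions import FusedFisherBlock

    model = extend(nn.Linear(20, 10))
    lossfunc = extend(nn.CrossEntropyLoss())

    X, y = torch.randn(32, 20), torch.randint(0, 10, (32,))
    loss = lossfunc(model(X), y)
    with backpack(FusedFisherBlock(damping=1.0)):
        loss.backward()

    for p in model.parameters():
        step = p.fused_fisher_block  # damped NGD step stored under the savefield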
--- backpack/extensions/secondorder/fused_fisher_block/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/__init__.py b/backpack/extensions/secondorder/fused_fisher_block/__init__.py index f97149977..d6265bf55 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/__init__.py +++ b/backpack/extensions/secondorder/fused_fisher_block/__init__.py @@ -11,8 +11,7 @@ ) class FusedFisherBlock(BackpropExtension): - def __init__(self, loss_sample, damping=1.0): - self.loss_sample = loss_sample + def __init__(self, damping=1.0): self.damping = damping super().__init__( savefield="fused_fisher_block", From e37f5610ef1169235c256c23ec7af5828dc32731 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Thu, 26 Aug 2021 14:06:08 -0400 Subject: [PATCH 06/19] set mc_samples = 1 for backprop efficiency. --- .../extensions/secondorder/fused_fisher_block/losses.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/losses.py b/backpack/extensions/secondorder/fused_fisher_block/losses.py index 48c756434..628bc4184 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/losses.py +++ b/backpack/extensions/secondorder/fused_fisher_block/losses.py @@ -6,15 +6,16 @@ class FusedFisherBlockLoss(FusedFisherBlockBaseModule): def backpropagate(self, ext, module, grad_inp, grad_out, backproped): - # backprop symmetric factorization of the hessian of loss w.r.t. outputs of the network, - # i.e. H = SS^T + # backprop symmetric factorization of the hessian of loss w.r.t. the network outputs, + # i.e. S in H = SS^T hess_func = self.make_loss_hessian_func(ext) return hess_func(module, grad_inp, grad_out) def make_loss_hessian_func(self, ext): # TODO(bmu): try both exact and MC sampling - return self.derivatives.sqrt_hessian + # set mc_samples = 1 for backprop efficiency + return self.derivatives.sqrt_hessian_sampled class FusedFisherBlockCrossEntropyLoss(FusedFisherBlockLoss): From 3030dd6a829c1e42c4fea2392044df38957d980f Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Thu, 26 Aug 2021 14:26:02 -0400 Subject: [PATCH 07/19] add linear bias. --- .../secondorder/fused_fisher_block/linear.py | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/linear.py b/backpack/extensions/secondorder/fused_fisher_block/linear.py index b66997f4c..93c749345 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/linear.py +++ b/backpack/extensions/secondorder/fused_fisher_block/linear.py @@ -8,7 +8,6 @@ class FusedFisherBlockLinear(FusedFisherBlockBaseModule): def __init__(self, damping=1.0): self.damping = damping - self.alpha = alpha super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) def weight(self, ext, module, g_inp, g_out, backproped): @@ -46,25 +45,30 @@ def weight(self, ext, module, g_inp, g_out, backproped): def bias(self, ext, module, g_inp, g_out, backproped): - # TODO(bmu): manully backprop quantities in the extra b/w pass - - grad = module.bias.grad - n = g_out[0].shape[0] - g_out_sc = n * g_out[0] + """ + y = wx + b + g_inp: tuple of [dl/db (avg) = sum of dl/dy over batch dim, dl/dx, dl/dw] + g_out: tuple of [dl/dy (individual, divided by batch size m)] + backproped: + * symmetric factorization of the hessian w.r.t. output, i.e. 
S in H = SS^T (scaled by 1/sqrt(m)) + * S^{(i-1)} = J^TS^{(i)} + * jacobian of loss w.r.t. bias params = transposed jacobian of output w.r.t. bias params @ S = S + fuse by manully backproping quantities in the extra b/w pass + """ + # derivative of loss w.r.t. bias parameters = derivatie of loss w.r.t. layer output + g = g_inp[0] + m = g_out[0].shape[0] + + J = backproped.squeeze() # compute vector jacobian product in optimization method - grad_prod = einsum("no,o->n", (g_out_sc, grad)) - # grad_prod = 0 - out = einsum("no,lo->nl", g_out_sc, g_out_sc) - # out = 0 + Jg = einsum("mp,p->m", (J, g)) + JTJ = einsum("mp,lp->ml", J, J) + JTJ_inv = inv(JTJ + self.damping * eye(m).to(g.device)) + v = matmul(JTJ_inv, Jg.unsqueeze(1)).squeeze() + gv = einsum("m,mp->p", (v, J)) - NGD_kernel = out / n - NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device)) - v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() - gv = einsum("n,no->o", (v, g_out_sc)) - gv = gv / n - - update = (grad - gv)/self.damping + update = (g - gv) / self.damping return update From 4d4d0418f0f9122d7e3adfecf0841e5dbba237e0 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Thu, 26 Aug 2021 15:23:01 -0400 Subject: [PATCH 08/19] add linear weight. --- .../secondorder/fused_fisher_block/linear.py | 66 ++++++++++--------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/linear.py b/backpack/extensions/secondorder/fused_fisher_block/linear.py index 93c749345..c11413b35 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/linear.py +++ b/backpack/extensions/secondorder/fused_fisher_block/linear.py @@ -11,36 +11,42 @@ def __init__(self, damping=1.0): super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) def weight(self, ext, module, g_inp, g_out, backproped): - # TODO(bmu): manully backprop quantities in the extra b/w pass + """ + y = wx + b + g_inp: tuple of [dl/db (avg) = sum of dl/dy over batch dim, dl/dx, dl/dw] + g_out: tuple of [dl/dy (individual, divided by batch size m)] + backproped: + * symmetric factorization of the hessian w.r.t. output, i.e. S in H = SS^T (scaled by 1/sqrt(m)) + * S^{(i-1)} = J^TS^{(i)} + * jacobian of loss w.r.t. weight params = transposed jacobian of output w.r.t. weight params @ S + fuse by manully backproping quantities in the extra b/w pass + """ + # derivative of loss w.r.t. weight parameters = transposed derivatie of loss w.r.t. 
layer output @ I + m = g_out[0].shape[0] I = module.input0 - n = g_out[0].shape[0] - g_out_sc = n * g_out[0] - G = g_out_sc - grad = module.weight.grad - - B = einsum("ni,li->nl", (I, I)) - A = einsum("no,lo->nl", (G, G)) - - # compute vector jacobian product in optimization method - grad_prod = einsum("ni,oi->no", (I, grad)) - grad_prod = einsum("no,no->n", (grad_prod, G)) - # grad_prod = 0 - out = A * B - # out = 0 - NGD_kernel = out / n - NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device)) - v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() - - gv = einsum("n,no->no", (v, G)) - gv = einsum("no,ni->oi", (gv, I)) - gv = gv / n - - update = (grad - gv)/self.damping + G = backproped.squeeze() # scaled by 1/sqrt(m) + g = g_inp[2] # g = dw = einsum("mo,mi->io", (g_out[0], I)) + + # compute the covariance factors II and GG + II = einsum("mi,li->ml", (I, I)) + GG = einsum("mo,lo->ml", (G, G)) + + # ngd update = J^T @ inv(JJ^T + damping * I) @ Jg + Jg = einsum("mi,io->mo", (I, g)) + Jg = einsum("mo,mo->m", (Jg, G)) + JJT = II * GG + JJT_inv = inv(JJT + self.damping * eye(m).to(g.device)) + v = matmul(JJT_inv, Jg.unsqueeze(1)).squeeze() + gv = einsum("m,mo->mo", (v, G)) + gv = einsum("mo,mi->oi", (gv, I)) + + update = (g.t() - gv) / self.damping module.I = I module.G = G - module.NGD_inv = NGD_inv + module.NGD_inv = JJT_inv + return update @@ -55,17 +61,17 @@ def bias(self, ext, module, g_inp, g_out, backproped): * jacobian of loss w.r.t. bias params = transposed jacobian of output w.r.t. bias params @ S = S fuse by manully backproping quantities in the extra b/w pass """ - # derivative of loss w.r.t. bias parameters = derivatie of loss w.r.t. layer output + # derivative of loss w.r.t. bias parameters = derivatie of loss w.r.t. layer output, i.e. J = G g = g_inp[0] m = g_out[0].shape[0] J = backproped.squeeze() - # compute vector jacobian product in optimization method + # ngd update = J^T @ inv(JJ^T + damping * I) @ Jg Jg = einsum("mp,p->m", (J, g)) - JTJ = einsum("mp,lp->ml", J, J) - JTJ_inv = inv(JTJ + self.damping * eye(m).to(g.device)) - v = matmul(JTJ_inv, Jg.unsqueeze(1)).squeeze() + JJT = einsum("mp,lp->ml", J, J) + JJT_inv = inv(JTJ + self.damping * eye(m).to(g.device)) + v = matmul(JJT_inv, Jg.unsqueeze(1)).squeeze() gv = einsum("m,mp->p", (v, J)) update = (g - gv) / self.damping From f296694770840c16d7f8019f8ed9b321dfbb42d3 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Thu, 26 Aug 2021 15:35:44 -0400 Subject: [PATCH 09/19] add relu. --- .../secondorder/fused_fisher_block/activations.py | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 backpack/extensions/secondorder/fused_fisher_block/activations.py diff --git a/backpack/extensions/secondorder/fused_fisher_block/activations.py b/backpack/extensions/secondorder/fused_fisher_block/activations.py new file mode 100644 index 000000000..ed4f2f37a --- /dev/null +++ b/backpack/extensions/secondorder/fused_fisher_block/activations.py @@ -0,0 +1,7 @@ +from backpack.core.derivatives.relu import ReLUDerivatives +from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule + + +class FusedFisherBlockReLU(FusedFisherBlockBaseModule): + def __init__(self): + super().__init__(derivatives=ReLUDerivatives()) From aa499377abe87a8f08e794abd3043c55a97c5a32 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Thu, 26 Aug 2021 15:48:20 -0400 Subject: [PATCH 10/19] add flatten. 
--- .../secondorder/fused_fisher_block/flatten.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 backpack/extensions/secondorder/fused_fisher_block/flatten.py diff --git a/backpack/extensions/secondorder/fused_fisher_block/flatten.py b/backpack/extensions/secondorder/fused_fisher_block/flatten.py new file mode 100644 index 000000000..5323cce10 --- /dev/null +++ b/backpack/extensions/secondorder/fused_fisher_block/flatten.py @@ -0,0 +1,13 @@ +from backpack.core.derivatives.flatten import FlattenDerivatives +from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule + + +class FusedFisherBlockFlatten(FusedFisherBlockBaseModule): + def __init__(self): + super().__init__(derivatives=FlattenDerivatives()) + + def backpropagate(self, ext, module, grad_inp, grad_out, backproped): + if self.derivatives.is_no_op(module): + return backproped + else: + return super().backpropagate(ext, module, grad_inp, grad_out, backproped) From 30d7ba3d7389a4200beefe3ba34bd6e1e111d816 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Thu, 26 Aug 2021 15:58:11 -0400 Subject: [PATCH 11/19] fix typo. --- backpack/extensions/secondorder/fused_fisher_block/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/linear.py b/backpack/extensions/secondorder/fused_fisher_block/linear.py index c11413b35..e9d25d1dc 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/linear.py +++ b/backpack/extensions/secondorder/fused_fisher_block/linear.py @@ -70,7 +70,7 @@ def bias(self, ext, module, g_inp, g_out, backproped): # ngd update = J^T @ inv(JJ^T + damping * I) @ Jg Jg = einsum("mp,p->m", (J, g)) JJT = einsum("mp,lp->ml", J, J) - JJT_inv = inv(JTJ + self.damping * eye(m).to(g.device)) + JJT_inv = inv(JJT + self.damping * eye(m).to(g.device)) v = matmul(JJT_inv, Jg.unsqueeze(1)).squeeze() gv = einsum("m,mp->p", (v, J)) From b6f738c14e5ea69745da834c5f3c949c64887a65 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Thu, 26 Aug 2021 15:58:37 -0400 Subject: [PATCH 12/19] enable relu and flatten. 
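With ReLU and Flatten wired into the module map (and into the io/extension hooks), plain feed-forward classifiers are covered end to end. A sketch of the kind of model this enables, and of consuming the saved step as a preconditioned update in place of p.grad (sizes, damping and learning rate are placeholders):

    import torch
    from torch import nn

    from backpack import backpack, extend
    from backpack.extensions import FusedFisherBlock

    model = extend(nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    ))
    lossfunc = extend(nn.CrossEntropyLoss())

    X, y = torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,))
    loss = lossfunc(model(X), y)
    with backpack(FusedFisherBlock(damping=0.1)):
        loss.backward()

    lr = 0.01
    with torch.no_grad():
        for p in model.parameters():
            p.sub_(lr * p.fused_fisher_block)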
--- backpack/__init__.py | 4 ++-- .../secondorder/fused_fisher_block/__init__.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/backpack/__init__.py b/backpack/__init__.py index 59efd5771..a278ad9ab 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -83,7 +83,7 @@ def hook_store_io(module, input, output): input: List of input tensors output: output tensor """ - if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.CrossEntropyLoss)): + if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.CrossEntropyLoss)) or isinstance(module, nn.ReLU) or isinstance(module, nn.Flatten): for i in range(len(input)): setattr(module, "input{}".format(i), input[i]) module.output = output @@ -134,7 +134,7 @@ def hook_run_extensions(module, g_inp, g_out): for backpack_extension in CTX.get_active_exts(): if CTX.get_debug(): print("[DEBUG] Running extension", backpack_extension, "on", module) - if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.CrossEntropyLoss): + if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.CrossEntropyLoss) or isinstance(module, nn.ReLU) or isinstance(module, nn.Flatten): backpack_extension.apply(module, g_inp, g_out) if not ( diff --git a/backpack/extensions/secondorder/fused_fisher_block/__init__.py b/backpack/extensions/secondorder/fused_fisher_block/__init__.py index d6265bf55..cb0b8fac7 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/__init__.py +++ b/backpack/extensions/secondorder/fused_fisher_block/__init__.py @@ -1,13 +1,17 @@ from torch.nn import ( CrossEntropyLoss, - Linear + Linear, + ReLU, + Flatten ) from backpack.extensions.backprop_extension import BackpropExtension from . import ( linear, - losses + losses, + activations, + flatten ) class FusedFisherBlock(BackpropExtension): @@ -18,6 +22,8 @@ def __init__(self, damping=1.0): fail_mode="WARNING", module_exts={ CrossEntropyLoss: losses.FusedFisherBlockCrossEntropyLoss(), - Linear: linear.FusedFisherBlockLinear(self.damping) + Linear: linear.FusedFisherBlockLinear(self.damping), + ReLU: activations.FusedFisherBlockReLU(), + Flatten: flatten.FusedFisherBlockFlatten() }, ) From 53b121946fc121c5319adb14c3f3126a3c082b77 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Sat, 28 Aug 2021 01:34:56 -0400 Subject: [PATCH 13/19] fix my dumb comments :\ --- .../secondorder/fused_fisher_block/linear.py | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/linear.py b/backpack/extensions/secondorder/fused_fisher_block/linear.py index e9d25d1dc..704af0027 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/linear.py +++ b/backpack/extensions/secondorder/fused_fisher_block/linear.py @@ -15,24 +15,27 @@ def weight(self, ext, module, g_inp, g_out, backproped): y = wx + b g_inp: tuple of [dl/db (avg) = sum of dl/dy over batch dim, dl/dx, dl/dw] g_out: tuple of [dl/dy (individual, divided by batch size m)] - backproped: - * symmetric factorization of the hessian w.r.t. output, i.e. S in H = SS^T (scaled by 1/sqrt(m)) - * S^{(i-1)} = J^TS^{(i)} - * jacobian of loss w.r.t. weight params = transposed jacobian of output w.r.t. 
weight params @ S + backproped B: + * [c(number of classes), m(batch size), o(number of outputs)] + * batched symmetric factorization of G(y) = J^T H J (scaled by 1/sqrt(m), where + * J is the Jacobian of network outputs w.r.t. y + * H is the Hessian of loss w.r.t network outputs + * so initially the symmetric factorization of Hessian of loss w.r.t. network outputs, i.e. S in H = SS^T + * backpropagation to the previous layer by left multiplying the Jacobian of y w.r.t. x + * batched symmetric factorization of GGN/FIM G(w) = transposed Jacobian of output y w.r.t. weight params w @ B fuse by manully backproping quantities in the extra b/w pass """ - # derivative of loss w.r.t. weight parameters = transposed derivatie of loss w.r.t. layer output @ I m = g_out[0].shape[0] I = module.input0 - G = backproped.squeeze() # scaled by 1/sqrt(m) + G = backproped.squeeze() # mc_samples = 1, scaled by 1/sqrt(m) g = g_inp[2] # g = dw = einsum("mo,mi->io", (g_out[0], I)) # compute the covariance factors II and GG II = einsum("mi,li->ml", (I, I)) GG = einsum("mo,lo->ml", (G, G)) - # ngd update = J^T @ inv(JJ^T + damping * I) @ Jg + # ngd update + SMW formula = J^T @ inv(JJ^T + damping * I) @ Jg Jg = einsum("mi,io->mo", (I, g)) Jg = einsum("mo,mo->m", (Jg, G)) JJT = II * GG @@ -55,19 +58,22 @@ def bias(self, ext, module, g_inp, g_out, backproped): y = wx + b g_inp: tuple of [dl/db (avg) = sum of dl/dy over batch dim, dl/dx, dl/dw] g_out: tuple of [dl/dy (individual, divided by batch size m)] - backproped: - * symmetric factorization of the hessian w.r.t. output, i.e. S in H = SS^T (scaled by 1/sqrt(m)) - * S^{(i-1)} = J^TS^{(i)} - * jacobian of loss w.r.t. bias params = transposed jacobian of output w.r.t. bias params @ S = S + backproped B: + * [c(number of classes), m(batch size), o(number of outputs)] + * batched symmetric factorization of G(y) = J^T H J (scaled by 1/sqrt(m), where + * J is the Jacobian of network outputs w.r.t. y + * H is the Hessian of loss w.r.t network outputs + * so initially the symmetric factorization of Hessian of loss w.r.t. network outputs, i.e. S in H = SS^T + * backpropagation to the previous layer by left multiplying the Jacobian of y w.r.t. x + * batched symmetric factorization of GGN/FIM G(b) = transposed Jacobian of output y w.r.t. bias params b @ B = B fuse by manully backproping quantities in the extra b/w pass """ - # derivative of loss w.r.t. bias parameters = derivatie of loss w.r.t. layer output, i.e. J = G g = g_inp[0] m = g_out[0].shape[0] - J = backproped.squeeze() + J = backproped.squeeze() # mc_samples = 1 - # ngd update = J^T @ inv(JJ^T + damping * I) @ Jg + # ngd update + SMW formula = J^T @ inv(JJ^T + damping * I) @ Jg Jg = einsum("mp,p->m", (J, g)) JJT = einsum("mp,lp->ml", J, J) JJT_inv = inv(JJT + self.damping * eye(m).to(g.device)) From e73654a6cdc2695af6617e25df9dc6b57c8df57d Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Sun, 29 Aug 2021 02:29:30 -0400 Subject: [PATCH 14/19] try exact hessian of loss w.r.t. network outputs. 
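For cross-entropy on softmax outputs, the Hessian of the loss w.r.t. the network outputs z is available in closed form,

    H = diag(p) - p p^T,    p = softmax(z),

and sqrt_hessian returns an exact symmetric factor S with H = S S^T: c backpropagated columns per sample instead of the single column from sqrt_hessian_sampled with mc_samples = 1. The linear rules are adapted to the resulting [c, m, o] factor in the next patch.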
--- backpack/extensions/secondorder/fused_fisher_block/losses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/losses.py b/backpack/extensions/secondorder/fused_fisher_block/losses.py index 628bc4184..917fb01e1 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/losses.py +++ b/backpack/extensions/secondorder/fused_fisher_block/losses.py @@ -15,7 +15,7 @@ def backpropagate(self, ext, module, grad_inp, grad_out, backproped): def make_loss_hessian_func(self, ext): # TODO(bmu): try both exact and MC sampling # set mc_samples = 1 for backprop efficiency - return self.derivatives.sqrt_hessian_sampled + return self.derivatives.sqrt_hessian class FusedFisherBlockCrossEntropyLoss(FusedFisherBlockLoss): From f496221077c9739440a42584f90f2fa95560a52f Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Sun, 29 Aug 2021 02:30:42 -0400 Subject: [PATCH 15/19] try linear with exact hessian of loss w.r.t. output. --- .../secondorder/fused_fisher_block/linear.py | 37 ++++++++++--------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/linear.py b/backpack/extensions/secondorder/fused_fisher_block/linear.py index 704af0027..3d483a04f 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/linear.py +++ b/backpack/extensions/secondorder/fused_fisher_block/linear.py @@ -25,24 +25,26 @@ def weight(self, ext, module, g_inp, g_out, backproped): * batched symmetric factorization of GGN/FIM G(w) = transposed Jacobian of output y w.r.t. weight params w @ B fuse by manully backproping quantities in the extra b/w pass """ - m = g_out[0].shape[0] - + c, m, o = backproped.size() I = module.input0 - G = backproped.squeeze() # mc_samples = 1, scaled by 1/sqrt(m) + # G = backproped.squeeze() # mc_samples = 1, scaled by 1/sqrt(m) + G = backproped # [c, m, o] g = g_inp[2] # g = dw = einsum("mo,mi->io", (g_out[0], I)) # compute the covariance factors II and GG - II = einsum("mi,li->ml", (I, I)) - GG = einsum("mo,lo->ml", (G, G)) + II = einsum("mi,li->ml", (I, I)) # [m, m], memory efficient + GG = einsum("cmo,vlo->cmvl", (G, G)) # [mc, mc] # ngd update + SMW formula = J^T @ inv(JJ^T + damping * I) @ Jg Jg = einsum("mi,io->mo", (I, g)) - Jg = einsum("mo,mo->m", (Jg, G)) - JJT = II * GG - JJT_inv = inv(JJT + self.damping * eye(m).to(g.device)) + Jg = einsum("mo,cmo->cm", (Jg, G)) + Jg = Jg.reshape(-1) + JJT = einsum("mo,cmvo->cmvo", (II, GG)).reshape(c * m, c * m) + JJT_inv = inv(JJT + self.damping * eye(c * m).to(g.device)) v = matmul(JJT_inv, Jg.unsqueeze(1)).squeeze() - gv = einsum("m,mo->mo", (v, G)) - gv = einsum("mo,mi->oi", (gv, I)) + gv = einsum("q,qo->qo", (v, G.reshape(c * m, o))) + gv = gv.reshape(c, m, o) + gv = einsum("cmo,mi->oi", (gv, I)) update = (g.t() - gv) / self.damping @@ -69,16 +71,17 @@ def bias(self, ext, module, g_inp, g_out, backproped): fuse by manully backproping quantities in the extra b/w pass """ g = g_inp[0] - m = g_out[0].shape[0] - - J = backproped.squeeze() # mc_samples = 1 + c, m, p = backproped.size() + # J = backproped.squeeze() # mc_samples = 1 + J = backproped.reshape(-1, p) # [cm, p] # ngd update + SMW formula = J^T @ inv(JJ^T + damping * I) @ Jg - Jg = einsum("mp,p->m", (J, g)) - JJT = einsum("mp,lp->ml", J, J) - JJT_inv = inv(JJT + self.damping * eye(m).to(g.device)) + Jg = einsum("qp,p->q", (J, g)) + + JJT = einsum("qp,rp->qr", J, J) # [cm, cm] + JJT_inv = inv(JJT + self.damping * eye(m * 
c).to(g.device)) v = matmul(JJT_inv, Jg.unsqueeze(1)).squeeze() - gv = einsum("m,mp->p", (v, J)) + gv = einsum("q,qp->p", (v, J)) update = (g - gv) / self.damping From c4c84f0f25c4a6652c03484ce4c344ec6546d6c6 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Sun, 29 Aug 2021 04:00:32 -0400 Subject: [PATCH 16/19] backprop H_inv, J of f w.r.t. y, shape (m, c). --- .../extensions/secondorder/fused_fisher_block/losses.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/losses.py b/backpack/extensions/secondorder/fused_fisher_block/losses.py index 917fb01e1..ad5133e65 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/losses.py +++ b/backpack/extensions/secondorder/fused_fisher_block/losses.py @@ -1,5 +1,8 @@ from functools import partial +from torch.linalg import inv +from torch import einsum, eye + from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule @@ -9,8 +12,12 @@ def backpropagate(self, ext, module, grad_inp, grad_out, backproped): # backprop symmetric factorization of the hessian of loss w.r.t. the network outputs, # i.e. S in H = SS^T hess_func = self.make_loss_hessian_func(ext) + sqrt_H = hess_func(module, grad_inp, grad_out) + c_, m, c = sqrt_H.size() + H = einsum('omc,olv->cmvl', (sqrt_H, sqrt_H)).reshape(c * m, c * m) + H_inv = inv(H) - return hess_func(module, grad_inp, grad_out) + return (H_inv, eye(c, c * m).to(H_inv.device), (m, c)) def make_loss_hessian_func(self, ext): # TODO(bmu): try both exact and MC sampling From 4f0437e5db2edaa4d4aa5431621defba9ae141ea Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Sun, 29 Aug 2021 04:03:07 -0400 Subject: [PATCH 17/19] fix SMW for batched J'HJ. --- .../secondorder/fused_fisher_block/linear.py | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/linear.py b/backpack/extensions/secondorder/fused_fisher_block/linear.py index 3d483a04f..c0aaa40f5 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/linear.py +++ b/backpack/extensions/secondorder/fused_fisher_block/linear.py @@ -25,26 +25,32 @@ def weight(self, ext, module, g_inp, g_out, backproped): * batched symmetric factorization of GGN/FIM G(w) = transposed Jacobian of output y w.r.t. weight params w @ B fuse by manully backproping quantities in the extra b/w pass """ - c, m, o = backproped.size() I = module.input0 - # G = backproped.squeeze() # mc_samples = 1, scaled by 1/sqrt(m) - G = backproped # [c, m, o] + + # --- I: mc_samples = 1 --- + # G = backproped.squeeze() # scaled by 1/sqrt(m) + + # --- II: exact hessian of loss w.r.t. 
network outputs --- + H_inv, G, (m, c) = backproped + g = g_inp[2] # g = dw = einsum("mo,mi->io", (g_out[0], I)) + o = G.size(0) + G = G.reshape(c, m, o) # compute the covariance factors II and GG II = einsum("mi,li->ml", (I, I)) # [m, m], memory efficient GG = einsum("cmo,vlo->cmvl", (G, G)) # [mc, mc] - # ngd update + SMW formula = J^T @ inv(JJ^T + damping * I) @ Jg + # GGN/FIM precondition + SMW formula = 1/λ [I - 1/m J'(λH^{−1} + 1/m JJ')^{-1}J]g Jg = einsum("mi,io->mo", (I, g)) Jg = einsum("mo,cmo->cm", (Jg, G)) Jg = Jg.reshape(-1) - JJT = einsum("mo,cmvo->cmvo", (II, GG)).reshape(c * m, c * m) - JJT_inv = inv(JJT + self.damping * eye(c * m).to(g.device)) + JJT = einsum("mo,cmvo->cmvo", (II, GG)).reshape(c * m, c * m) / m + JJT_inv = inv(JJT + self.damping * H_inv) v = matmul(JJT_inv, Jg.unsqueeze(1)).squeeze() gv = einsum("q,qo->qo", (v, G.reshape(c * m, o))) gv = gv.reshape(c, m, o) - gv = einsum("cmo,mi->oi", (gv, I)) + gv = einsum("cmo,mi->oi", (gv, I)) / m update = (g.t() - gv) / self.damping @@ -71,17 +77,20 @@ def bias(self, ext, module, g_inp, g_out, backproped): fuse by manully backproping quantities in the extra b/w pass """ g = g_inp[0] - c, m, p = backproped.size() - # J = backproped.squeeze() # mc_samples = 1 - J = backproped.reshape(-1, p) # [cm, p] - # ngd update + SMW formula = J^T @ inv(JJ^T + damping * I) @ Jg - Jg = einsum("qp,p->q", (J, g)) + # --- I: mc_samples = 1 --- + # J = backproped.squeeze() + + # --- II: exact hessian of loss w.r.t. network outputs --- + H_inv, J, (m, c) = backproped + + # GGN/FIM precondition + SMW formula = 1/λ [I - 1/m J'(λH^{−1} + 1/m JJ')^{-1}J]g + Jg = einsum("pq,p->q", (J, g)) # q = cm - JJT = einsum("qp,rp->qr", J, J) # [cm, cm] - JJT_inv = inv(JJT + self.damping * eye(m * c).to(g.device)) + JJT = einsum("pq,pr->qr", J, J) / m # [cm, cm] + JJT_inv = inv(JJT + self.damping * H_inv) v = matmul(JJT_inv, Jg.unsqueeze(1)).squeeze() - gv = einsum("q,qp->p", (v, J)) + gv = einsum("q,pq->p", (v, J)) / m update = (g - gv) / self.damping From d34c5a51dbd424160eb0979b10801919a38d6a75 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Sun, 29 Aug 2021 06:38:50 -0400 Subject: [PATCH 18/19] customize "backpropagate()". 
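The base module no longer goes through MatToJacMat; backpropagate carries the triple (H_inv, J, (m, c)) instead. H_inv, the inverse Hessian of the loss w.r.t. the network outputs, is computed once at the loss layer and passed along untouched, while J is pushed back through each module by its transposed Jacobian,

    J^{(l-1)} = (dy^{(l)}/dy^{(l-1)})^T J^{(l)},

so at a parametrized layer J holds the per-sample Jacobian of the network outputs w.r.t. that layer's output (shape [c, m, o]).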
--- .../fused_fisher_block/fused_fisher_block_base.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py b/backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py index a2f1a5746..94c508345 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py +++ b/backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py @@ -1,6 +1,11 @@ -from backpack.extensions.mat_to_mat_jac_base import MatToJacMat +from backpack.extensions.module_extension import ModuleExtension -class FusedFisherBlockBaseModule(MatToJacMat): +class FusedFisherBlockBaseModule(ModuleExtension): def __init__(self, derivatives, params=None): - super().__init__(derivatives, params=params) + super().__init__(params=params) + self.derivatives = derivatives + + def backpropagate(self, ext, module, g_inp, g_out, backproped): + H_inv, J, (m, c) = backproped + return [H_inv, self.derivatives.jac_t_mat_prod(module, g_inp, g_out, J), (m, c)] From fe519ea37e5a3b28352c69ee3f391a8791d28258 Mon Sep 17 00:00:00 2001 From: nobody <63745715+2zki3@users.noreply.github.com> Date: Sun, 29 Aug 2021 06:40:11 -0400 Subject: [PATCH 19/19] fix J's shape, use vectorized instead of flattened. --- backpack/extensions/secondorder/fused_fisher_block/linear.py | 4 ++-- backpack/extensions/secondorder/fused_fisher_block/losses.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/backpack/extensions/secondorder/fused_fisher_block/linear.py b/backpack/extensions/secondorder/fused_fisher_block/linear.py index c0aaa40f5..7c24f95ec 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/linear.py +++ b/backpack/extensions/secondorder/fused_fisher_block/linear.py @@ -32,10 +32,9 @@ def weight(self, ext, module, g_inp, g_out, backproped): # --- II: exact hessian of loss w.r.t. network outputs --- H_inv, G, (m, c) = backproped + c, m, o = G.size() g = g_inp[2] # g = dw = einsum("mo,mi->io", (g_out[0], I)) - o = G.size(0) - G = G.reshape(c, m, o) # compute the covariance factors II and GG II = einsum("mi,li->ml", (I, I)) # [m, m], memory efficient @@ -83,6 +82,7 @@ def bias(self, ext, module, g_inp, g_out, backproped): # --- II: exact hessian of loss w.r.t. network outputs --- H_inv, J, (m, c) = backproped + J = J.reshape(-1, c * m) # GGN/FIM precondition + SMW formula = 1/λ [I - 1/m J'(λH^{−1} + 1/m JJ')^{-1}J]g Jg = einsum("pq,p->q", (J, g)) # q = cm diff --git a/backpack/extensions/secondorder/fused_fisher_block/losses.py b/backpack/extensions/secondorder/fused_fisher_block/losses.py index ad5133e65..8c313b20b 100644 --- a/backpack/extensions/secondorder/fused_fisher_block/losses.py +++ b/backpack/extensions/secondorder/fused_fisher_block/losses.py @@ -17,7 +17,7 @@ def backpropagate(self, ext, module, grad_inp, grad_out, backproped): H = einsum('omc,olv->cmvl', (sqrt_H, sqrt_H)).reshape(c * m, c * m) H_inv = inv(H) - return (H_inv, eye(c, c * m).to(H_inv.device), (m, c)) + return (H_inv, eye(c, c * m).to(H_inv.device).reshape(c, m, c), (m, c)) def make_loss_hessian_func(self, ext): # TODO(bmu): try both exact and MC sampling