diff --git a/backpack/__init__.py b/backpack/__init__.py
index 4f0deb64d..a278ad9ab 100644
--- a/backpack/__init__.py
+++ b/backpack/__init__.py
@@ -83,7 +83,7 @@ def hook_store_io(module, input, output):
         input: List of input tensors
         output: output tensor
     """
-    if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)):
+    if module.training and isinstance(module, (nn.Conv2d, nn.Linear, nn.CrossEntropyLoss, nn.ReLU, nn.Flatten)):
         for i in range(len(input)):
             setattr(module, "input{}".format(i), input[i])
         module.output = output
@@ -134,7 +134,7 @@ def hook_run_extensions(module, g_inp, g_out):
     for backpack_extension in CTX.get_active_exts():
         if CTX.get_debug():
             print("[DEBUG] Running extension", backpack_extension, "on", module)
-        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
+        if isinstance(module, (nn.Conv2d, nn.Linear, nn.CrossEntropyLoss, nn.ReLU, nn.Flatten)):
             backpack_extension.apply(module, g_inp, g_out)

     if not (
diff --git a/backpack/extensions/__init__.py b/backpack/extensions/__init__.py
index d6549ea58..91b760de2 100644
--- a/backpack/extensions/__init__.py
+++ b/backpack/extensions/__init__.py
@@ -15,6 +15,7 @@
     DiagGGNMC,
     DiagHessian,
     MNGD,
+    FusedFisherBlock,
 )

 __all__ = [
@@ -38,4 +39,5 @@
     "DiagGGNMC",
     "DiagGGN",
     "DiagHessian",
+    "FusedFisherBlock",
 ]
diff --git a/backpack/extensions/secondorder/__init__.py b/backpack/extensions/secondorder/__init__.py
index 1f7f24ec9..0adabe19d 100644
--- a/backpack/extensions/secondorder/__init__.py
+++ b/backpack/extensions/secondorder/__init__.py
@@ -23,6 +23,7 @@
 from .diag_hessian import DiagHessian
 from .hbp import HBP, KFAC, KFLR, KFRA
 from .mngd import MNGD
+from .fused_fisher_block import FusedFisherBlock

 __all__ = [
     "MNGD",
@@ -35,4 +36,5 @@
     "KFLR",
     "KFRA",
     "HBP",
+    "FusedFisherBlock",
 ]
diff --git a/backpack/extensions/secondorder/fused_fisher_block/__init__.py b/backpack/extensions/secondorder/fused_fisher_block/__init__.py
new file mode 100644
index 000000000..cb0b8fac7
--- /dev/null
+++ b/backpack/extensions/secondorder/fused_fisher_block/__init__.py
@@ -0,0 +1,29 @@
+from torch.nn import CrossEntropyLoss, Flatten, Linear, ReLU
+
+from backpack.extensions.backprop_extension import BackpropExtension
+
+from . import activations, flatten, linear, losses
+
+
+class FusedFisherBlock(BackpropExtension):
+    """Damped Fisher/GGN block preconditioning, fused into a single backward pass."""
+
+    def __init__(self, damping=1.0):
+        self.damping = damping
+        super().__init__(
+            savefield="fused_fisher_block",
+            fail_mode="WARNING",
+            module_exts={
+                CrossEntropyLoss: losses.FusedFisherBlockCrossEntropyLoss(),
+                Linear: linear.FusedFisherBlockLinear(self.damping),
+                ReLU: activations.FusedFisherBlockReLU(),
+                Flatten: flatten.FusedFisherBlockFlatten(),
+            },
+        )
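For context, the new extension plugs into BackPACK's usual workflow. A minimal usage sketch (standalone, not part of the diff; the toy model, shapes, and variable names are illustrative, while `extend`, `backpack`, and the `fused_fisher_block` savefield come from the library and this patch):

```python
import torch
from backpack import backpack, extend
from backpack.extensions import FusedFisherBlock

# Toy classifier built only from module types this patch handles.
model = extend(torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
))
loss_fn = extend(torch.nn.CrossEntropyLoss())

X = torch.randn(32, 1, 28, 28)
y = torch.randint(0, 10, (32,))

with backpack(FusedFisherBlock(damping=1.0)):
    loss_fn(model(X), y).backward()

for p in model.parameters():
    # Damped natural-gradient update direction, one block per parameter.
    print(p.fused_fisher_block.shape)
```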
diff --git a/backpack/extensions/secondorder/fused_fisher_block/activations.py b/backpack/extensions/secondorder/fused_fisher_block/activations.py
new file mode 100644
index 000000000..ed4f2f37a
--- /dev/null
+++ b/backpack/extensions/secondorder/fused_fisher_block/activations.py
@@ -0,0 +1,7 @@
+from backpack.core.derivatives.relu import ReLUDerivatives
+from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule
+
+
+class FusedFisherBlockReLU(FusedFisherBlockBaseModule):
+    def __init__(self):
+        super().__init__(derivatives=ReLUDerivatives())
diff --git a/backpack/extensions/secondorder/fused_fisher_block/flatten.py b/backpack/extensions/secondorder/fused_fisher_block/flatten.py
new file mode 100644
index 000000000..5323cce10
--- /dev/null
+++ b/backpack/extensions/secondorder/fused_fisher_block/flatten.py
@@ -0,0 +1,13 @@
+from backpack.core.derivatives.flatten import FlattenDerivatives
+from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule
+
+
+class FusedFisherBlockFlatten(FusedFisherBlockBaseModule):
+    def __init__(self):
+        super().__init__(derivatives=FlattenDerivatives())
+
+    def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
+        # If flattening is a no-op (nothing to reshape), pass the
+        # backpropagated quantities through unchanged.
+        if self.derivatives.is_no_op(module):
+            return backproped
+        else:
+            return super().backpropagate(ext, module, grad_inp, grad_out, backproped)
diff --git a/backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py b/backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py
new file mode 100644
index 000000000..94c508345
--- /dev/null
+++ b/backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py
@@ -0,0 +1,11 @@
+from backpack.extensions.module_extension import ModuleExtension
+
+
+class FusedFisherBlockBaseModule(ModuleExtension):
+    def __init__(self, derivatives, params=None):
+        super().__init__(params=params)
+        self.derivatives = derivatives
+
+    def backpropagate(self, ext, module, g_inp, g_out, backproped):
+        # H_inv and the sizes (m, c) pass through every layer unchanged;
+        # J is pulled back through the layer's transposed Jacobian.
+        H_inv, J, (m, c) = backproped
+        return [H_inv, self.derivatives.jac_t_mat_prod(module, g_inp, g_out, J), (m, c)]
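For intuition about the tuple the base module passes between layers: `H_inv` and the sizes are constants of the extra backward pass, while `J` is pulled back layer by layer. A standalone sketch of the ReLU case (not part of the diff; shapes follow the `[c, m, o]` convention of `linear.py` below, and the mask multiplication is what `jac_t_mat_prod` amounts to for ReLU):

```python
import torch

c, m, o = 3, 4, 5                 # classes, batch size, layer width
H_inv = torch.eye(c * m)          # passes through every layer unchanged
J = torch.randn(c, m, o)          # backpropagated Jacobian block
x = torch.randn(m, o)             # ReLU input stored by hook_store_io

# ReLU's transposed Jacobian is elementwise multiplication with the
# 0/1 activation mask, broadcast over the class dimension.
mask = (x > 0).to(J.dtype)        # [m, o]
J_pulled_back = J * mask.unsqueeze(0)

print(J_pulled_back.shape)        # torch.Size([3, 4, 5])
```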
diff --git a/backpack/extensions/secondorder/fused_fisher_block/linear.py b/backpack/extensions/secondorder/fused_fisher_block/linear.py
new file mode 100644
index 000000000..7c24f95ec
--- /dev/null
+++ b/backpack/extensions/secondorder/fused_fisher_block/linear.py
@@ -0,0 +1,98 @@
+from torch import einsum, matmul
+from torch.linalg import inv
+
+from backpack.core.derivatives.linear import LinearDerivatives
+from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule
+
+
+class FusedFisherBlockLinear(FusedFisherBlockBaseModule):
+    def __init__(self, damping=1.0):
+        self.damping = damping
+        super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"])
+
+    def weight(self, ext, module, g_inp, g_out, backproped):
+        """Preconditioned weight update via the SMW identity.
+
+        y = Wx + b
+        g_inp: tuple of (dl/db = dl/dy summed over the batch, dl/dx, dl/dW)
+        g_out: tuple of (per-sample dl/dy, divided by the batch size m)
+        backproped: tuple (H_inv, G, (m, c)), where
+            * H_inv [cm, cm] is the inverse Hessian of the loss w.r.t. the
+              network outputs (c classes, m samples),
+            * G [c, m, o] is the backpropagated Jacobian of the network
+              outputs w.r.t. this layer's output y (o units); the Jacobian
+              w.r.t. the weight factorizes as J[(c,m),(o,i)] = G[c,m,o] * x[m,i].
+        The base module pulls G back to the previous layer through the
+        transposed Jacobian of y w.r.t. x. Fusing: instead of materializing
+        the [oi, oi] GGN/Fisher block G(W) = 1/m J^T H J and inverting it,
+        apply the damped inverse to the gradient with the SMW identity
+            (G(W) + λI)^{-1} g = 1/λ [g - 1/m J^T (λ H^{-1} + 1/m J J^T)^{-1} J g],
+        backpropagating the needed quantities manually in the extra backward pass.
+        """
+        I = module.input0
+
+        # --- I: mc_samples = 1 ---
+        # G = backproped.squeeze()  # scaled by 1/sqrt(m)
+
+        # --- II: exact Hessian of the loss w.r.t. the network outputs ---
+        H_inv, G, (m, c) = backproped
+        c, m, o = G.size()
+
+        g = g_inp[2]  # g = dW = einsum("mo,mi->io", (g_out[0], I))
+
+        # Covariance factors of J J^T: input Gram and output Gram matrices.
+        II = einsum("mi,li->ml", (I, I))  # [m, m], memory efficient
+        GG = einsum("cmo,vlo->cmvl", (G, G))  # [c, m, c, m]
+
+        # SMW: 1/λ [I - 1/m J^T (λ H^{-1} + 1/m J J^T)^{-1} J] g
+        Jg = einsum("mi,io->mo", (I, g))
+        Jg = einsum("mo,cmo->cm", (Jg, G))
+        Jg = Jg.reshape(-1)
+        JJT = einsum("ml,cmvl->cmvl", (II, GG)).reshape(c * m, c * m) / m
+        JJT_inv = inv(JJT + self.damping * H_inv)
+        v = matmul(JJT_inv, Jg.unsqueeze(1)).squeeze()
+        gv = einsum("q,qo->qo", (v, G.reshape(c * m, o)))
+        gv = gv.reshape(c, m, o)
+        gv = einsum("cmo,mi->oi", (gv, I)) / m
+
+        update = (g.t() - gv) / self.damping
+
+        # Cache for later reuse (e.g., by an optimizer step).
+        module.I = I
+        module.G = G
+        module.NGD_inv = JJT_inv
+
+        return update
+
+    def bias(self, ext, module, g_inp, g_out, backproped):
+        """Preconditioned bias update via the SMW identity.
+
+        y = Wx + b
+        g_inp: tuple of (dl/db = dl/dy summed over the batch, dl/dx, dl/dW)
+        g_out: tuple of (per-sample dl/dy, divided by the batch size m)
+        backproped: tuple (H_inv, J, (m, c)) as in weight(); since dy/db is
+        the identity, J [c, m, o] is itself the Jacobian of the network
+        outputs w.r.t. the bias.
+        """
+        g = g_inp[0]
+
+        # --- I: mc_samples = 1 ---
+        # J = backproped.squeeze()
+
+        # --- II: exact Hessian of the loss w.r.t. the network outputs ---
+        H_inv, J, (m, c) = backproped
+        J = J.reshape(c * m, -1).t()  # [o, cm]: rows index bias units, columns (class, sample) pairs
+
+        # SMW: 1/λ [I - 1/m J^T (λ H^{-1} + 1/m J J^T)^{-1} J] g
+        Jg = einsum("pq,p->q", (J, g))  # q = cm
+
+        JJT = einsum("pq,pr->qr", (J, J)) / m  # [cm, cm]
+        JJT_inv = inv(JJT + self.damping * H_inv)
+        v = matmul(JJT_inv, Jg.unsqueeze(1)).squeeze()
+        gv = einsum("q,pq->p", (v, J)) / m
+
+        update = (g - gv) / self.damping
+
+        return update
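The SMW identity used by both `weight` and `bias` can be checked in isolation. A standalone numerical sketch (not part of the diff; `H` is taken as the identity so `H_inv` drops out, and all sizes are arbitrary):

```python
import torch

torch.manual_seed(0)
cm, d, m, lam = 6, 10, 3, 0.5     # rows of J (c*m), parameters, batch size, damping
J = torch.randn(cm, d)
g = torch.randn(d)

# Direct damped inverse: (J^T J / m + lam I)^{-1} g.
G = J.t() @ J / m
direct = torch.linalg.solve(G + lam * torch.eye(d), g)

# SMW form with H = I: 1/lam [g - 1/m J^T (lam H^{-1} + 1/m J J^T)^{-1} J g].
# Only a [cm, cm] system is solved, never the [d, d] one.
inner = torch.linalg.solve(lam * torch.eye(cm) + J @ J.t() / m, J @ g)
smw = (g - J.t() @ inner / m) / lam

print(torch.allclose(direct, smw, atol=1e-5))  # True
```

This is the point of the fusion: `d` is the number of layer parameters (e.g. `o * i` for a Linear weight), so the direct `[d, d]` inverse is far more expensive than the `[cm, cm]` solve.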
diff --git a/backpack/extensions/secondorder/fused_fisher_block/losses.py b/backpack/extensions/secondorder/fused_fisher_block/losses.py
new file mode 100644
index 000000000..8c313b20b
--- /dev/null
+++ b/backpack/extensions/secondorder/fused_fisher_block/losses.py
@@ -0,0 +1,30 @@
+from torch import einsum, eye
+from torch.linalg import inv
+
+from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives
+from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule
+
+
+class FusedFisherBlockLoss(FusedFisherBlockBaseModule):
+    def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
+        # Seed the extra backward pass at the loss: from the symmetric
+        # factorization S of the loss Hessian w.r.t. the network outputs
+        # (H = S S^T), build H^{-1} and the initial Jacobian, which is the
+        # identity over the c classes for every sample.
+        hess_func = self.make_loss_hessian_func(ext)
+        sqrt_H = hess_func(module, grad_inp, grad_out)  # [c, m, c]
+        c_, m, c = sqrt_H.size()
+
+        # The loss is a sum of per-sample terms, so H is block diagonal
+        # over the batch: H[(c,m),(v,l)] = delta_{ml} * H_m[c,v].
+        H_blocks = einsum("omc,omv->mcv", (sqrt_H, sqrt_H))  # [m, c, c]
+        H = einsum("mcv,ml->cmvl", (H_blocks, eye(m, device=sqrt_H.device, dtype=sqrt_H.dtype)))
+        H_inv = inv(H.reshape(c * m, c * m))
+
+        J = eye(c, device=H_inv.device).unsqueeze(1).repeat(1, m, 1)  # [c, m, c]
+        return (H_inv, J, (m, c))
+
+    def make_loss_hessian_func(self, ext):
+        # TODO(bmu): try both exact and MC sampling
+        # set mc_samples = 1 for backprop efficiency
+        return self.derivatives.sqrt_hessian
+
+
+class FusedFisherBlockCrossEntropyLoss(FusedFisherBlockLoss):
+    def __init__(self):
+        super().__init__(derivatives=CrossEntropyLossDerivatives())
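The factorization the loss extension starts from can also be checked directly. For softmax cross-entropy, the Hessian of the loss w.r.t. the logits is diag(p) - p p^T, and S = diag(sqrt(p)) - p sqrt(p)^T is one symmetric factor with H = S S^T. A standalone sketch (not part of the diff; uses autograd rather than BackPACK's `sqrt_hessian`):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5)
target = torch.tensor(2)

# Hessian of the cross-entropy loss w.r.t. the logits, via autograd.
loss_fn = lambda z: F.cross_entropy(z.unsqueeze(0), target.unsqueeze(0))
H = torch.autograd.functional.hessian(loss_fn, logits)

# Closed form diag(p) - p p^T and one symmetric factor S with H = S S^T.
p = torch.softmax(logits, dim=0)
H_closed = torch.diag(p) - torch.outer(p, p)
S = torch.diag(p.sqrt()) - torch.outer(p, p.sqrt())

print(torch.allclose(H, H_closed, atol=1e-6))    # True
print(torch.allclose(S @ S.t(), H, atol=1e-6))   # True
```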