4 changes: 2 additions & 2 deletions backpack/__init__.py
@@ -83,7 +83,7 @@ def hook_store_io(module, input, output):
input: List of input tensors
output: output tensor
"""
if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)):
if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.CrossEntropyLoss) or isinstance(module, nn.ReLU) or isinstance(module, nn.Flatten)):
for i in range(len(input)):
setattr(module, "input{}".format(i), input[i])
module.output = output
@@ -134,7 +134,7 @@ def hook_run_extensions(module, g_inp, g_out):
for backpack_extension in CTX.get_active_exts():
if CTX.get_debug():
print("[DEBUG] Running extension", backpack_extension, "on", module)
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.CrossEntropyLoss) or isinstance(module, nn.ReLU) or isinstance(module, nn.Flatten):
backpack_extension.apply(module, g_inp, g_out)

if not (
2 changes: 2 additions & 0 deletions backpack/extensions/__init__.py
@@ -15,6 +15,7 @@
DiagGGNMC,
DiagHessian,
MNGD,
FusedFisherBlock
)

__all__ = [
@@ -38,4 +39,5 @@
"DiagGGNMC",
"DiagGGN",
"DiagHessian",
"FusedFisherBlock"
]
2 changes: 2 additions & 0 deletions backpack/extensions/secondorder/__init__.py
@@ -23,6 +23,7 @@
from .diag_hessian import DiagHessian
from .hbp import HBP, KFAC, KFLR, KFRA
from .mngd import MNGD
from .fused_fisher_block import FusedFisherBlock

__all__ = [
"MNGD",
@@ -35,4 +36,5 @@
"KFLR",
"KFRA",
"HBP",
"FusedFisherBlock"
]
29 changes: 29 additions & 0 deletions backpack/extensions/secondorder/fused_fisher_block/__init__.py
@@ -0,0 +1,29 @@
from torch.nn import (
CrossEntropyLoss,
Linear,
ReLU,
Flatten
)

from backpack.extensions.backprop_extension import BackpropExtension

from . import (
linear,
losses,
activations,
flatten
)

class FusedFisherBlock(BackpropExtension):
def __init__(self, damping=1.0):
self.damping = damping
super().__init__(
savefield="fused_fisher_block",
fail_mode="WARNING",
module_exts={
CrossEntropyLoss: losses.FusedFisherBlockCrossEntropyLoss(),
Linear: linear.FusedFisherBlockLinear(self.damping),
ReLU: activations.FusedFisherBlockReLU(),
Flatten: flatten.FusedFisherBlockFlatten()
},
)
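
For orientation, a minimal usage sketch of the new extension with BackPACK's usual extend/backpack workflow. The toy model, shapes, and variable names below are illustrative assumptions; only FusedFisherBlock, its damping argument, and the fused_fisher_block savefield come from this PR.

# Hypothetical usage sketch: the model and data are made up for illustration.
import torch
from torch import nn
from backpack import backpack, extend
from backpack.extensions import FusedFisherBlock

model = extend(nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 3)))
lossfunc = extend(nn.CrossEntropyLoss())

X, y = torch.randn(8, 20), torch.randint(0, 3, (8,))
loss = lossfunc(model(X), y)

with backpack(FusedFisherBlock(damping=1.0)):
    loss.backward()  # the extension runs during this extra backward pass

for name, p in model.named_parameters():
    # each Linear parameter now carries the preconditioned update computed by the extension
    print(name, p.fused_fisher_block.shape)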
7 changes: 7 additions & 0 deletions backpack/extensions/secondorder/fused_fisher_block/activations.py
@@ -0,0 +1,7 @@
from backpack.core.derivatives.relu import ReLUDerivatives
from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule


class FusedFisherBlockReLU(FusedFisherBlockBaseModule):
def __init__(self):
super().__init__(derivatives=ReLUDerivatives())
13 changes: 13 additions & 0 deletions backpack/extensions/secondorder/fused_fisher_block/flatten.py
@@ -0,0 +1,13 @@
from backpack.core.derivatives.flatten import FlattenDerivatives
from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule


class FusedFisherBlockFlatten(FusedFisherBlockBaseModule):
def __init__(self):
super().__init__(derivatives=FlattenDerivatives())

def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
if self.derivatives.is_no_op(module):
return backproped
else:
return super().backpropagate(ext, module, grad_inp, grad_out, backproped)
11 changes: 11 additions & 0 deletions backpack/extensions/secondorder/fused_fisher_block/fused_fisher_block_base.py
@@ -0,0 +1,11 @@
from backpack.extensions.module_extension import ModuleExtension


class FusedFisherBlockBaseModule(ModuleExtension):
def __init__(self, derivatives, params=None):
super().__init__(params=params)
self.derivatives = derivatives

def backpropagate(self, ext, module, g_inp, g_out, backproped):
H_inv, J, (m, c) = backproped
return [H_inv, self.derivatives.jac_t_mat_prod(module, g_inp, g_out, J), (m, c)]
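
For reference, the identity that backpropagate above implements, stated as a short LaTeX sketch in the same S-notation as this PR's docstrings: if the GGN of the loss w.r.t. a module's output y factors as G(y) = S_y S_y^T, then the factor at the module's input x is obtained by applying the transposed module Jacobian, which is exactly what jac_t_mat_prod does to the middle entry of the backpropagated triple while H_inv and (m, c) pass through unchanged.

\[
G(x) \;=\; \Big(\tfrac{\partial y}{\partial x}\Big)^{\!\top} G(y)\, \tfrac{\partial y}{\partial x},
\qquad
G(y) = S_y S_y^{\top}
\;\Longrightarrow\;
S_x \;=\; \Big(\tfrac{\partial y}{\partial x}\Big)^{\!\top} S_y .
\]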
98 changes: 98 additions & 0 deletions backpack/extensions/secondorder/fused_fisher_block/linear.py
@@ -0,0 +1,98 @@
from torch import einsum, eye, matmul, ones_like, norm
from torch.linalg import inv

from backpack.core.derivatives.linear import LinearDerivatives
from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule


class FusedFisherBlockLinear(FusedFisherBlockBaseModule):
def __init__(self, damping=1.0):
self.damping = damping
super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"])

def weight(self, ext, module, g_inp, g_out, backproped):
"""
y = wx + b
g_inp: tuple of [dl/db (avg) = sum of dl/dy over batch dim, dl/dx, dl/dw]
g_out: tuple of [dl/dy (individual, divided by batch size m)]
backproped B:
* [c(number of classes), m(batch size), o(number of outputs)]
* batched symmetric factorization of G(y) = J^T H J (scaled by 1/sqrt(m)), where
* J is the Jacobian of network outputs w.r.t. y
* H is the Hessian of the loss w.r.t. the network outputs
* so initially the symmetric factorization of the Hessian of the loss w.r.t. the network outputs, i.e. S in H = SS^T
* backpropagated to the previous layer by left-multiplying with the transposed Jacobian of y w.r.t. x
* batched symmetric factorization of GGN/FIM G(w) = transposed Jacobian of output y w.r.t. weight params w @ B
fused by manually backpropagating these quantities in the extra backward pass
"""
I = module.input0

# --- I: mc_samples = 1 ---
# G = backproped.squeeze() # scaled by 1/sqrt(m)

# --- II: exact hessian of loss w.r.t. network outputs ---
H_inv, G, (m, c) = backproped
c, m, o = G.size()

g = g_inp[2] # g = dw = einsum("mo,mi->io", (g_out[0], I))

# compute the covariance factors II and GG
II = einsum("mi,li->ml", (I, I)) # [m, m], memory efficient
GG = einsum("cmo,vlo->cmvl", (G, G)) # [mc, mc]

# GGN/FIM preconditioning + SMW formula = 1/λ [I - 1/m J'(λH^{−1} + 1/m JJ')^{-1}J]g (derivation sketched after this file's diff)
Jg = einsum("mi,io->mo", (I, g))
Jg = einsum("mo,cmo->cm", (Jg, G))
Jg = Jg.reshape(-1)
JJT = einsum("mo,cmvo->cmvo", (II, GG)).reshape(c * m, c * m) / m
JJT_inv = inv(JJT + self.damping * H_inv)
v = matmul(JJT_inv, Jg.unsqueeze(1)).squeeze()
gv = einsum("q,qo->qo", (v, G.reshape(c * m, o)))
gv = gv.reshape(c, m, o)
gv = einsum("cmo,mi->oi", (gv, I)) / m

update = (g.t() - gv) / self.damping

module.I = I
module.G = G
module.NGD_inv = JJT_inv

return update


def bias(self, ext, module, g_inp, g_out, backproped):
"""
y = wx + b
g_inp: tuple of [dl/db (avg) = sum of dl/dy over batch dim, dl/dx, dl/dw]
g_out: tuple of [dl/dy (individual, divided by batch size m)]
backproped B:
* [c(number of classes), m(batch size), o(number of outputs)]
* batched symmetric factorization of G(y) = J^T H J (scaled by 1/sqrt(m)), where
* J is the Jacobian of network outputs w.r.t. y
* H is the Hessian of the loss w.r.t. the network outputs
* so initially the symmetric factorization of the Hessian of the loss w.r.t. the network outputs, i.e. S in H = SS^T
* backpropagated to the previous layer by left-multiplying with the transposed Jacobian of y w.r.t. x
* batched symmetric factorization of GGN/FIM G(b) = transposed Jacobian of output y w.r.t. bias params b @ B = B
fused by manually backpropagating these quantities in the extra backward pass
"""
g = g_inp[0]

# --- I: mc_samples = 1 ---
# J = backproped.squeeze()

# --- II: exact hessian of loss w.r.t. network outputs ---
H_inv, J, (m, c) = backproped
J = J.reshape(-1, c * m)

# GGN/FIM preconditioning + SMW formula = 1/λ [I - 1/m J'(λH^{−1} + 1/m JJ')^{-1}J]g
Jg = einsum("pq,p->q", (J, g)) # q = cm

JJT = einsum("pq,pr->qr", J, J) / m # [cm, cm]
JJT_inv = inv(JJT + self.damping * H_inv)
v = matmul(JJT_inv, Jg.unsqueeze(1)).squeeze()
gv = einsum("q,pq->p", (v, J)) / m

update = (g - gv) / self.damping

return update
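
Background for the preconditioning step in weight and bias above: a LaTeX sketch of the Sherman-Morrison-Woodbury identity quoted in the comments. Here J is the Jacobian of the network outputs w.r.t. the parameter block (the backpropagated factor above), H the Hessian of the loss w.r.t. the network outputs, m the batch size, λ = self.damping, and the per-block Fisher/GGN is taken as F = (1/m) J^T H J; these identifications are my reading of the docstrings, not text from the PR.

\[
(F + \lambda I)^{-1}
= \frac{1}{\lambda}\Big[\,I \;-\; \tfrac{1}{m}\, J^{\top}\big(\lambda H^{-1} + \tfrac{1}{m}\, J J^{\top}\big)^{-1} J\,\Big],
\]

so applying the preconditioner to a gradient g only requires inverting the cm × cm matrix λH^{-1} + (1/m) JJ^T, which is what the code stores as JJT_inv; the final division by self.damping corresponds to the leading 1/λ.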

30 changes: 30 additions & 0 deletions backpack/extensions/secondorder/fused_fisher_block/losses.py
@@ -0,0 +1,30 @@
from functools import partial

from torch.linalg import inv
from torch import einsum, eye

from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives
from backpack.extensions.secondorder.fused_fisher_block.fused_fisher_block_base import FusedFisherBlockBaseModule


class FusedFisherBlockLoss(FusedFisherBlockBaseModule):
def backpropagate(self, ext, module, grad_inp, grad_out, backproped):
# backprop symmetric factorization of the hessian of loss w.r.t. the network outputs,
# i.e. S in H = SS^T
hess_func = self.make_loss_hessian_func(ext)
sqrt_H = hess_func(module, grad_inp, grad_out)
c_, m, c = sqrt_H.size()
H = einsum('omc,olv->cmvl', (sqrt_H, sqrt_H)).reshape(c * m, c * m)
H_inv = inv(H)

return (H_inv, eye(c, c * m).to(H_inv.device).reshape(c, m, c), (m, c))

def make_loss_hessian_func(self, ext):
# TODO(bmu): try both exact and MC sampling
# set mc_samples = 1 for backprop efficiency
return self.derivatives.sqrt_hessian


class FusedFisherBlockCrossEntropyLoss(FusedFisherBlockLoss):
def __init__(self):
super().__init__(derivatives=CrossEntropyLossDerivatives())
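
A closing note on sqrt_hessian for the cross-entropy case, stated as standard background rather than a quote of BackPACK's implementation: with per-sample softmax probabilities p_n, the loss Hessian w.r.t. the logits admits the closed-form symmetric factorization below; with mean reduction each factor carries an extra 1/√m, matching the "scaled by 1/sqrt(m)" note in linear.py.

\[
H_n = \operatorname{diag}(p_n) - p_n p_n^{\top} = S_n S_n^{\top},
\qquad
S_n = \operatorname{diag}\!\big(\sqrt{p_n}\big) - p_n \sqrt{p_n}^{\,\top}.
\]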