Binary file modified .DS_Store
Binary file modified backpack/.DS_Store
4 changes: 2 additions & 2 deletions backpack/__init__.py
@@ -83,7 +83,7 @@ def hook_store_io(module, input, output):
         input: List of input tensors
         output: output tensor
     """
-    if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)):
+    if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.LayerNorm)):
         for i in range(len(input)):
             setattr(module, "input{}".format(i), input[i])
         module.output = output
@@ -134,7 +134,7 @@ def hook_run_extensions(module, g_inp, g_out):
     for backpack_extension in CTX.get_active_exts():
         if CTX.get_debug():
             print("[DEBUG] Running extension", backpack_extension, "on", module)
-        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
+        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.LayerNorm):
             backpack_extension.apply(module, g_inp, g_out)

     if not (
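With both guards widened to include nn.LayerNorm, an extended LayerNorm module now gets its forward input/output cached by hook_store_io and is visited by hook_run_extensions during the backward pass, just as Conv2d and Linear already are. A minimal sketch of the first effect, assuming the fork keeps upstream BackPACK's extend() API (the toy model and shapes are only for illustration):

    import torch
    import torch.nn as nn
    from backpack import extend

    # extend() registers hook_store_io as a forward hook on every submodule
    model = extend(nn.Sequential(nn.Linear(10, 16), nn.LayerNorm(16)))
    x = torch.randn(8, 10)

    model.train()  # hook_store_io only stores I/O in training mode
    model(x)

    ln = model[1]
    # with this change, the LayerNorm's tensors are now cached as well
    print(ln.input0.shape, ln.output.shape)  # both torch.Size([8, 16])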
3 changes: 3 additions & 0 deletions backpack/core/derivatives/__init__.py
@@ -10,6 +10,7 @@
     ConvTranspose3d,
     CrossEntropyLoss,
     Dropout,
+    LayerNorm,
     LeakyReLU,
     Linear,
     LogSigmoid,
@@ -23,6 +24,7 @@
     BatchNorm2d
 )

+from .basederivatives import BaseParameterDerivatives
 from .avgpool2d import AvgPool2DDerivatives
 from .conv1d import Conv1DDerivatives
 from .conv_transpose1d import ConvTranspose1DDerivatives
@@ -47,6 +49,7 @@
 from .batchnorm2d import BatchNorm2dDerivatives

 derivatives_for = {
+    LayerNorm: BaseParameterDerivatives,
     Linear: LinearDerivatives,
     Conv1d: Conv1DDerivatives,
     Conv2d: Conv2DDerivatives,
Binary file modified backpack/extensions/.DS_Store
Binary file modified backpack/extensions/firstorder/.DS_Store
Binary file modified backpack/extensions/firstorder/fisher_block_eff/.DS_Store
14 changes: 9 additions & 5 deletions backpack/extensions/firstorder/fisher_block_eff/__init__.py
@@ -3,7 +3,8 @@
     Conv2d,
     Linear,
     BatchNorm1d,
-    BatchNorm2d
+    BatchNorm2d,
+    LayerNorm
 )

 from backpack.extensions.backprop_extension import BackpropExtension
@@ -13,28 +14,31 @@
     conv2d,
     linear,
     batchnorm1d,
-    batchnorm2d
+    batchnorm2d,
+    layernorm
 )


 class FisherBlockEff(BackpropExtension):


-    def __init__(self, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false'):
+    def __init__(self, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false', save_kernel='false'):
         self.gamma = gamma
         self.damping = damping
         self.alpha = alpha
         self.low_rank = low_rank
         self.memory_efficient = memory_efficient
         self.super_opt = super_opt
+        self.save_kernel = save_kernel
         super().__init__(
             savefield="fisher_block",
             fail_mode="WARNING",
             module_exts={
-                Linear: linear.FisherBlockEffLinear(self.damping, self.alpha),
+                Linear: linear.FisherBlockEffLinear(self.damping, self.alpha, self.save_kernel),
                 Conv1d: conv1d.FisherBlockEffConv1d(self.damping),
-                Conv2d: conv2d.FisherBlockEffConv2d(self.damping, self.low_rank, self.gamma, self.memory_efficient, self.super_opt),
+                Conv2d: conv2d.FisherBlockEffConv2d(self.damping, self.low_rank, self.gamma, self.memory_efficient, self.super_opt, self.save_kernel),
                 BatchNorm1d: batchnorm1d.FisherBlockEffBatchNorm1d(self.damping),
                 BatchNorm2d: batchnorm2d.FisherBlockEffBatchNorm2d(self.damping),
+                LayerNorm: layernorm.FisherBlockEffLayerNorm(self.damping, self.save_kernel)
             },
         )
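A sketch of how the new save_kernel flag would be used, assuming the fork follows BackPACK's usual extend / with backpack(...) workflow and re-exports FisherBlockEff from backpack.extensions (the toy model and the import path are illustrative; fisher_block is the savefield declared above and NGD_kernel is the attribute stored by the layer extensions):

    import torch
    import torch.nn as nn
    from backpack import backpack, extend
    from backpack.extensions import FisherBlockEff  # import path assumed

    model = extend(nn.Sequential(nn.Linear(10, 16), nn.LayerNorm(16), nn.Linear(16, 4)))
    lossfunc = extend(nn.CrossEntropyLoss())

    x, y = torch.randn(8, 10), torch.randint(0, 4, (8,))
    loss = lossfunc(model(x), y)

    # save_kernel='true' asks each supported layer to keep its n x n NGD kernel
    with backpack(FisherBlockEff(damping=1.0, save_kernel='true')):
        loss.backward()

    for layer in model.modules():
        if isinstance(layer, (nn.Linear, nn.LayerNorm)):
            print(type(layer).__name__,
                  layer.weight.fisher_block.shape,  # preconditioned update, saved per parameter
                  layer.NGD_kernel.shape)           # 8 x 8 kernel kept on the module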
6 changes: 4 additions & 2 deletions backpack/extensions/firstorder/fisher_block_eff/conv2d.py
@@ -14,12 +14,13 @@
 # import matplotlib.pylab as plt
 MODE = 0
 class FisherBlockEffConv2d(FisherBlockEffBase):
-    def __init__(self, damping=1.0, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false'):
+    def __init__(self, damping=1.0, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false', save_kernel='false'):
         self.damping = damping
         self.low_rank = low_rank
         self.gamma = gamma
         self.memory_efficient = memory_efficient
         self.super_opt = super_opt
+        self.save_kernel = save_kernel
         super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"])

     def weight(self, ext, module, g_inp, g_out, bpQuantities):
@@ -106,7 +107,8 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
         gv = einsum("nkm,n->mk", (AX, v)).view_as(grad) /n
         module.AX = AX

-
+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel

         update = (grad - gv)/self.damping
         return update
112 changes: 112 additions & 0 deletions backpack/extensions/firstorder/fisher_block_eff/layernorm.py
@@ -0,0 +1,112 @@
import torch

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.extensions.firstorder.fisher_block_eff.fisher_block_eff_base import FisherBlockEffBase

class FisherBlockEffLayerNorm(FisherBlockEffBase):
    def __init__(self, damping=1.0, save_kernel='false'):
        self.damping = damping
        self.save_kernel = save_kernel
        super().__init__(derivatives=BaseParameterDerivatives(), params=["bias", "weight"])

    def weight(self, ext, module, g_inp, g_out, backproped):
        I = module.input0
        assert(len(I.shape) in [2, 4])  # linear or conv
        n, c, h, w = I.shape[0], I.shape[1], 1, 1  # input shape = output shape for LayerNorm
        g_out_sc = n * g_out[0]

        if len(I.shape) == 4:  # conv
            h, w = I.shape[2], I.shape[3]
            # flatten: [n, c, h * w]
            I = I.reshape(n, c, -1)
            g_out_sc = g_out_sc.reshape(n, c, -1)

        G = g_out_sc

        grad = module.weight.grad.reshape(-1)

        # compute mean, var over [c, h, w]
        if len(I.shape) == 2:
            mean = I.mean(dim=-1).unsqueeze(-1)
            var = I.var(dim=-1, unbiased=False).unsqueeze(-1)
        else:
            mean = I.mean((-2, -1), keepdim=True)
            var = I.var((-2, -1), unbiased=False, keepdim=True)

        # compute mean, var over [h, w]
        # mean = I.mean(dim=-1).unsqueeze(-1)
        # var = I.var(dim=-1, unbiased=False).unsqueeze(-1)

        x_hat = (I - mean) / (var + module.eps).sqrt()

        # per-sample gradient w.r.t. the LayerNorm weight: output gradient times normalized input
        J = g_out_sc * x_hat

        # if len(I.shape) == 2:
        #     J = g_out_sc * x_hat
        # else:
        #     J = torch.einsum('ncf,ncf->nf', g_out_sc, x_hat)

        J = J.reshape(J.shape[0], -1)
        JJT = torch.matmul(J, J.t())  # n x n Gram matrix of per-sample gradients

        grad_prod = torch.matmul(J, grad)

        NGD_kernel = JJT / n
        NGD_inv = torch.linalg.inv(NGD_kernel + self.damping * torch.eye(n).to(grad.device))
        v = torch.matmul(NGD_inv, grad_prod)

        gv = torch.matmul(J.t(), v) / n

        update = (grad - gv) / self.damping
        update = update.reshape(module.weight.grad.shape)

        module.I = I
        module.G = G
        module.NGD_inv = NGD_inv
        if self.save_kernel == 'true':
            module.NGD_kernel = NGD_kernel

        return update

    def bias(self, ext, module, g_inp, g_out, backproped):
        I = module.input0
        assert(len(I.shape) in [2, 4])
        n, c, h, w = I.shape[0], I.shape[1], 1, 1
        g_out_sc = n * g_out[0]

        if len(I.shape) == 4:
            h, w = I.shape[2], I.shape[3]
            I = I.reshape(n, c, -1)
            g_out_sc = g_out_sc.reshape(n, c, -1)

        G = g_out_sc

        grad = module.bias.grad.reshape(-1)

        # per-sample gradient w.r.t. the LayerNorm bias is just the output gradient
        J = g_out_sc

        # if len(I.shape) == 2:
        #     J = g_out_sc
        # else:
        #     J = torch.einsum('ncf->nf', g_out_sc)

        J = J.reshape(J.shape[0], -1)
        JJT = torch.matmul(J, J.t())

        grad_prod = torch.matmul(J, grad)

        NGD_kernel = JJT / n
        NGD_inv = torch.linalg.inv(NGD_kernel + self.damping * torch.eye(n).to(grad.device))
        v = torch.matmul(NGD_inv, grad_prod)

        gv = torch.matmul(J.t(), v) / n

        update = (grad - gv) / self.damping
        update = update.reshape(module.bias.grad.shape)

        module.NGD_inv = NGD_inv
        module.G = G
        if self.save_kernel == 'true':
            module.NGD_kernel = NGD_kernel

        return update
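For reference, both methods above compute the same kernel-trick form of a damped natural-gradient step. Writing J for the n x d matrix of per-sample gradients of the parameter (built from g_out and, for the weight, x_hat), g for the flattened gradient, and lambda for the damping, the code amounts to (a sketch of the algebra, not part of the diff):

    K = \tfrac{1}{n} J J^\top, \qquad
    v = \bigl(K + \lambda I_n\bigr)^{-1} J g, \qquad
    \text{update} = \tfrac{1}{\lambda}\Bigl(g - \tfrac{1}{n} J^\top v\Bigr)

so the matrix that gets inverted (and optionally kept on the module as NGD_kernel when save_kernel == 'true') is n x n, i.e. batch-size sized, rather than the d x d Fisher block itself.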
8 changes: 7 additions & 1 deletion backpack/extensions/firstorder/fisher_block_eff/linear.py
@@ -6,9 +6,10 @@


 class FisherBlockEffLinear(FisherBlockEffBase):
-    def __init__(self, damping=1.0, alpha=0.95):
+    def __init__(self, damping=1.0, alpha=0.95, save_kernel='false'):
         self.damping = damping
         self.alpha = alpha
+        self.save_kernel = save_kernel
         super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"])

     def weight(self, ext, module, g_inp, g_out, backproped):
@@ -42,6 +43,8 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         module.I = I
         module.G = G
         module.NGD_inv = NGD_inv
+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel
         return update


@@ -68,6 +71,9 @@ def bias(self, ext, module, g_inp, g_out, backproped):
         update = (grad - gv)/self.damping
         # update = grad

+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel

         return update