4 changes: 2 additions & 2 deletions backpack/__init__.py
@@ -83,7 +83,7 @@ def hook_store_io(module, input, output):
         input: List of input tensors
         output: output tensor
     """
-    if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)):
+    if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d)):
         for i in range(len(input)):
             setattr(module, "input{}".format(i), input[i])
         module.output = output
@@ -134,7 +134,7 @@ def hook_run_extensions(module, g_inp, g_out):
     for backpack_extension in CTX.get_active_exts():
         if CTX.get_debug():
             print("[DEBUG] Running extension", backpack_extension, "on", module)
-        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
+        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d):
             backpack_extension.apply(module, g_inp, g_out)

     if not (
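Aside (not part of this diff): both hooks now repeat the same four isinstance checks. A minimal sketch of an equivalent, easier-to-extend spelling, assuming nothing beyond standard PyTorch, is a shared tuple of supported module types:

import torch.nn as nn

# Hypothetical helper, not in the patch: isinstance accepts a tuple of types,
# so the growing or-chain in hook_store_io and hook_run_extensions collapses
# to a single membership check.
SUPPORTED_MODULES = (nn.Conv2d, nn.Linear, nn.BatchNorm1d, nn.BatchNorm2d)

def is_supported(module):
    # True if BackPACK's IO-storage and extension hooks should fire for this module.
    return isinstance(module, SUPPORTED_MODULES)

With such a helper, the two changed conditions would read `if module.training and is_supported(module):` and `if is_supported(module):` respectively.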
10 changes: 5 additions & 5 deletions backpack/extensions/firstorder/fisher_block_eff/batchnorm1d.py
@@ -3,21 +3,21 @@

 from torch import einsum, eye, matmul, ones_like, norm
 from torch.linalg import inv
+import torch

 class FisherBlockEffBatchNorm1d(FisherBlockEffBase):
     def __init__(self, damping=1.0):
         self.damping = damping
         super().__init__(derivatives=BatchNorm1dDerivatives(), params=["bias", "weight"])

     def weight(self, ext, module, g_inp, g_out, backproped):
-        return module.weight.grad
+        update = torch.empty_like(module.weight.grad).copy_(module.weight.grad)
+        return update

     def bias(self, ext, module, g_inp, g_out, backproped):
-        return module.bias.grad
+        update = torch.empty_like(module.bias.grad).copy_(module.bias.grad)
+        return update
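A short sketch of the copy semantics introduced above (the values are illustrative, not from the patch): torch.empty_like(g).copy_(g) allocates fresh storage and copies the gradient values into it, so the returned update can later be modified in place without touching module.weight.grad or module.bias.grad.

import torch

g = torch.randn(5)                        # stands in for module.weight.grad
original = g.clone()
update = torch.empty_like(g).copy_(g)     # fresh tensor, same values
update.mul_(0.5)                          # in-place edit of the copy ...
assert torch.equal(g, original)           # ... leaves the gradient untouched
assert update.data_ptr() != g.data_ptr()  # separate storage

g.detach().clone() is an equivalent and arguably more idiomatic spelling of the same copy.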



10 changes: 6 additions & 4 deletions backpack/extensions/firstorder/fisher_block_eff/batchnorm2d.py
@@ -4,18 +4,20 @@
 from torch import einsum, eye, matmul, ones_like, norm
 from torch.linalg import inv

+import torch
+
 class FisherBlockEffBatchNorm2d(FisherBlockEffBase):
     def __init__(self, damping=1.0):
         self.damping = damping
         super().__init__(derivatives=BatchNorm2dDerivatives(), params=["bias", "weight"])

     def weight(self, ext, module, g_inp, g_out, backproped):
-        return module.weight.grad
+        update = torch.empty_like(module.weight.grad).copy_(module.weight.grad)
+        return update

     def bias(self, ext, module, g_inp, g_out, backproped):
-        return module.bias.grad
+        update = torch.empty_like(module.bias.grad).copy_(module.bias.grad)
+        return update
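Minimal end-to-end usage sketch, assuming this fork exposes a top-level FisherBlockEff extension analogous to BackPACK's built-in extensions; the class name, import path, and storage location are assumptions, since only the per-module classes appear in this diff.

import torch
import torch.nn as nn
from backpack import backpack, extend
from backpack.extensions import FisherBlockEff  # assumed export, not shown in the diff

# A small network containing a BatchNorm1d layer, extended for BackPACK hooks.
model = extend(nn.Sequential(nn.Linear(10, 5), nn.BatchNorm1d(5), nn.ReLU(), nn.Linear(5, 2)))
loss_fn = extend(nn.CrossEntropyLoss())

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = loss_fn(model(x), y)

with backpack(FisherBlockEff(damping=1.0)):
    loss.backward()

# For the BatchNorm1d weight and bias, the quantity computed by weight()/bias()
# above is a copy of .grad; where the extension stores it depends on the
# (unshown) FisherBlockEff base class.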