diff --git a/backpack/__init__.py b/backpack/__init__.py
index 4f0deb64d..5ca3d7829 100644
--- a/backpack/__init__.py
+++ b/backpack/__init__.py
@@ -83,7 +83,7 @@ def hook_store_io(module, input, output):
         input: List of input tensors
         output: output tensor
     """
-    if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)):
+    if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d)):
         for i in range(len(input)):
             setattr(module, "input{}".format(i), input[i])
         module.output = output
@@ -134,7 +134,7 @@ def hook_run_extensions(module, g_inp, g_out):
     for backpack_extension in CTX.get_active_exts():
         if CTX.get_debug():
             print("[DEBUG] Running extension", backpack_extension, "on", module)
-        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
+        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d):
             backpack_extension.apply(module, g_inp, g_out)
 
     if not (
diff --git a/backpack/extensions/firstorder/fisher_block_eff/batchnorm1d.py b/backpack/extensions/firstorder/fisher_block_eff/batchnorm1d.py
index 28b42c820..e87f9174e 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/batchnorm1d.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/batchnorm1d.py
@@ -3,6 +3,7 @@
 from torch import einsum, eye, matmul, ones_like, norm
 from torch.linalg import inv
+import torch
 
 
 class FisherBlockEffBatchNorm1d(FisherBlockEffBase):
     def __init__(self, damping=1.0):
@@ -10,14 +11,13 @@ def __init__(self, damping=1.0):
         super().__init__(derivatives=BatchNorm1dDerivatives(), params=["bias", "weight"])
 
     def weight(self, ext, module, g_inp, g_out, backproped):
-
-        return module.weight.grad
+        update = torch.empty_like(module.weight.grad).copy_(module.weight.grad)
+        return update
 
     def bias(self, ext, module, g_inp, g_out, backproped):
-
-
-        return module.bias.grad
+        update = torch.empty_like(module.bias.grad).copy_(module.bias.grad)
+        return update
diff --git a/backpack/extensions/firstorder/fisher_block_eff/batchnorm2d.py b/backpack/extensions/firstorder/fisher_block_eff/batchnorm2d.py
index 0daf659a4..e2adb2295 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/batchnorm2d.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/batchnorm2d.py
@@ -4,18 +4,20 @@
 from torch import einsum, eye, matmul, ones_like, norm
 from torch.linalg import inv
+import torch
+
 
 class FisherBlockEffBatchNorm2d(FisherBlockEffBase):
     def __init__(self, damping=1.0):
         self.damping = damping
         super().__init__(derivatives=BatchNorm2dDerivatives(), params=["bias", "weight"])
 
     def weight(self, ext, module, g_inp, g_out, backproped):
-
-        return module.weight.grad
+        update = torch.empty_like(module.weight.grad).copy_(module.weight.grad)
+        return update
 
     def bias(self, ext, module, g_inp, g_out, backproped):
-
-        return module.bias.grad
+        update = torch.empty_like(module.bias.grad).copy_(module.bias.grad)
+        return update
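
Note on the `torch.empty_like(...).copy_(...)` pattern in the two BatchNorm rules above: returning `module.weight.grad` or `module.bias.grad` directly hands back the very tensor stored in `.grad`, so any later in-place edit of the returned update would also mutate the accumulated gradient. The standalone sketch below (not part of the patch; plain PyTorch only, with made-up variable names) illustrates that aliasing difference; whether avoiding it is the actual motivation for the copy here is an assumption.

# Standalone sketch, not part of the patch: shows why returning a copy of .grad
# behaves differently from returning the .grad tensor itself.
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
out = bn(torch.randn(8, 4))
(out ** 2).sum().backward()  # fills bn.weight.grad with non-zero values

aliased = bn.weight.grad                                          # same object as .grad
copied = torch.empty_like(bn.weight.grad).copy_(bn.weight.grad)   # independent storage

print(aliased is bn.weight.grad)   # True: in-place edits would also change .grad
print(copied is bn.weight.grad)    # False: the update can be modified safely

aliased.mul_(0.0)                                      # zeroes .grad through the alias
print(torch.equal(copied, torch.zeros_like(copied)))   # False: the copy kept its values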