diff --git a/.DS_Store b/.DS_Store
index bd5233c34..2bab8a66a 100644
Binary files a/.DS_Store and b/.DS_Store differ
diff --git a/backpack/.DS_Store b/backpack/.DS_Store
index eb0087a9f..0054067bf 100644
Binary files a/backpack/.DS_Store and b/backpack/.DS_Store differ
diff --git a/backpack/__init__.py b/backpack/__init__.py
index 4f0deb64d..98cb741e5 100644
--- a/backpack/__init__.py
+++ b/backpack/__init__.py
@@ -83,7 +83,7 @@ def hook_store_io(module, input, output):
         input: List of input tensors
         output: output tensor
     """
-    if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)):
+    if module.training and (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.LayerNorm)):
         for i in range(len(input)):
             setattr(module, "input{}".format(i), input[i])
         module.output = output
@@ -134,7 +134,7 @@ def hook_run_extensions(module, g_inp, g_out):
     for backpack_extension in CTX.get_active_exts():
         if CTX.get_debug():
             print("[DEBUG] Running extension", backpack_extension, "on", module)
-        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
+        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) or isinstance(module, nn.LayerNorm):
             backpack_extension.apply(module, g_inp, g_out)
 
     if not (
diff --git a/backpack/core/derivatives/__init__.py b/backpack/core/derivatives/__init__.py
index b9d2f1f71..e610d499d 100644
--- a/backpack/core/derivatives/__init__.py
+++ b/backpack/core/derivatives/__init__.py
@@ -10,6 +10,7 @@
     ConvTranspose3d,
     CrossEntropyLoss,
     Dropout,
+    LayerNorm,
     LeakyReLU,
     Linear,
     LogSigmoid,
@@ -23,6 +24,7 @@
     BatchNorm2d
 )
 
+from .basederivatives import BaseParameterDerivatives
 from .avgpool2d import AvgPool2DDerivatives
 from .conv1d import Conv1DDerivatives
 from .conv_transpose1d import ConvTranspose1DDerivatives
@@ -47,6 +49,7 @@
 from .batchnorm2d import BatchNorm2dDerivatives
 
 derivatives_for = {
+    LayerNorm: BaseParameterDerivatives,
     Linear: LinearDerivatives,
     Conv1d: Conv1DDerivatives,
     Conv2d: Conv2DDerivatives,
diff --git a/backpack/extensions/.DS_Store b/backpack/extensions/.DS_Store
index 2659c74d0..cfeeba6f7 100644
Binary files a/backpack/extensions/.DS_Store and b/backpack/extensions/.DS_Store differ
diff --git a/backpack/extensions/firstorder/.DS_Store b/backpack/extensions/firstorder/.DS_Store
index b3fa9dd27..5e83910de 100644
Binary files a/backpack/extensions/firstorder/.DS_Store and b/backpack/extensions/firstorder/.DS_Store differ
diff --git a/backpack/extensions/firstorder/fisher_block_eff/.DS_Store b/backpack/extensions/firstorder/fisher_block_eff/.DS_Store
index 3b9c2e10d..3f93e63b9 100644
Binary files a/backpack/extensions/firstorder/fisher_block_eff/.DS_Store and b/backpack/extensions/firstorder/fisher_block_eff/.DS_Store differ
diff --git a/backpack/extensions/firstorder/fisher_block_eff/__init__.py b/backpack/extensions/firstorder/fisher_block_eff/__init__.py
index fda973749..37b5ef48c 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/__init__.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/__init__.py
@@ -3,7 +3,8 @@
     Conv2d,
     Linear,
     BatchNorm1d,
-    BatchNorm2d
+    BatchNorm2d,
+    LayerNorm
 )
 
 from backpack.extensions.backprop_extension import BackpropExtension
@@ -13,28 +14,31 @@
     conv2d,
     linear,
     batchnorm1d,
-    batchnorm2d
+    batchnorm2d,
+    layernorm
 )
 
 
 class FisherBlockEff(BackpropExtension):
-    def __init__(self, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false'):
+    def __init__(self, damping=1.0, alpha=0.95, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false', save_kernel='false'):
         self.gamma = gamma
         self.damping = damping
         self.alpha =alpha
         self.low_rank = low_rank
         self.memory_efficient = memory_efficient
         self.super_opt = super_opt
+        self.save_kernel = save_kernel
 
         super().__init__(
             savefield="fisher_block",
             fail_mode="WARNING",
             module_exts={
-                Linear: linear.FisherBlockEffLinear(self.damping, self.alpha),
+                Linear: linear.FisherBlockEffLinear(self.damping, self.alpha, self.save_kernel),
                 Conv1d: conv1d.FisherBlockEffConv1d(self.damping),
-                Conv2d: conv2d.FisherBlockEffConv2d(self.damping, self.low_rank, self.gamma, self.memory_efficient, self.super_opt),
+                Conv2d: conv2d.FisherBlockEffConv2d(self.damping, self.low_rank, self.gamma, self.memory_efficient, self.super_opt, self.save_kernel),
                 BatchNorm1d: batchnorm1d.FisherBlockEffBatchNorm1d(self.damping),
                 BatchNorm2d: batchnorm2d.FisherBlockEffBatchNorm2d(self.damping),
+                LayerNorm: layernorm.FisherBlockEffLayerNorm(self.damping, self.save_kernel)
             },
         )
 
diff --git a/backpack/extensions/firstorder/fisher_block_eff/conv2d.py b/backpack/extensions/firstorder/fisher_block_eff/conv2d.py
index 07b29db8c..9bf3b742e 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/conv2d.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/conv2d.py
@@ -14,12 +14,13 @@
 # import matplotlib.pylab as plt
 MODE = 0
 class FisherBlockEffConv2d(FisherBlockEffBase):
-    def __init__(self, damping=1.0, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false'):
+    def __init__(self, damping=1.0, low_rank='false', gamma=0.95, memory_efficient='false', super_opt='false', save_kernel='false'):
         self.damping = damping
         self.low_rank = low_rank
         self.gamma = gamma
         self.memory_efficient = memory_efficient
         self.super_opt = super_opt
+        self.save_kernel = save_kernel
         super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"])
 
     def weight(self, ext, module, g_inp, g_out, bpQuantities):
@@ -106,7 +107,8 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities):
             gv = einsum("nkm,n->mk", (AX, v)).view_as(grad) /n
 
             module.AX = AX
-
+            if self.save_kernel == 'true':
+                module.NGD_kernel = NGD_kernel
             update = (grad - gv)/self.damping
 
             return update
diff --git a/backpack/extensions/firstorder/fisher_block_eff/layernorm.py b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
new file mode 100644
index 000000000..a7aa0d654
--- /dev/null
+++ b/backpack/extensions/firstorder/fisher_block_eff/layernorm.py
@@ -0,0 +1,112 @@
+import torch
+
+from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
+from backpack.extensions.firstorder.fisher_block_eff.fisher_block_eff_base import FisherBlockEffBase
+
+class FisherBlockEffLayerNorm(FisherBlockEffBase):
+    def __init__(self, damping=1.0, save_kernel='false'):
+        self.damping = damping
+        self.save_kernel = save_kernel
+        super().__init__(derivatives=BaseParameterDerivatives(), params=["bias", "weight"])
+
+    def weight(self, ext, module, g_inp, g_out, backproped):
+        I = module.input0
+        assert(len(I.shape) in [2, 4])  # linear or conv
+        n, c, h, w = I.shape[0], I.shape[1], 1, 1  # input shape = output shape for LayerNorm
+        g_out_sc = n * g_out[0]
+
+        if len(I.shape) == 4:  # conv
+            h, w = I.shape[2], I.shape[3]
+            # flatten: [n, c, h * w]
+            I = I.reshape(n, c, -1)
+            g_out_sc = g_out_sc.reshape(n, c, -1)
+
+        G = g_out_sc
+
+        grad = module.weight.grad.reshape(-1)
+
+        # compute mean, var over [c, h, w]
+        if len(I.shape) == 2:
+            mean = I.mean(dim=-1).unsqueeze(-1)
+            var = I.var(dim=-1, unbiased=False).unsqueeze(-1)
+        else:
+            mean = I.mean((-2, -1), keepdims=True)
+            var = I.var((-2, -1), unbiased=False, keepdims=True)
+
+        # compute mean, var over [h, w]
+        # mean = I.mean(dim=-1).unsqueeze(-1)
+        # var = I.var(dim=-1, unbiased=False).unsqueeze(-1)
+
+        x_hat = (I - mean) / (var + module.eps).sqrt()
+
+        J = g_out_sc * x_hat
+
+        # if len(I.shape) == 2:
+        #     J = g_out_sc * x_hat
+        # else:
+        #     J = torch.einsum('ncf,ncf->nf', g_out_sc, x_hat)
+
+        J = J.reshape(J.shape[0], -1)
+        JJT = torch.matmul(J, J.t())
+
+        grad_prod = torch.matmul(J, grad)
+
+        NGD_kernel = JJT / n
+        NGD_inv = torch.linalg.inv(NGD_kernel + self.damping * torch.eye(n).to(grad.device))
+        v = torch.matmul(NGD_inv, grad_prod)
+
+        gv = torch.matmul(J.t(), v) / n
+
+        update = (grad - gv) / self.damping
+        update = update.reshape(module.weight.grad.shape)
+
+        module.I = I
+        module.G = G
+        module.NGD_inv = NGD_inv
+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel
+
+        return update
+
+    def bias(self, ext, module, g_inp, g_out, backproped):
+        I = module.input0
+        assert(len(I.shape) in [2, 4])
+        n, c, h, w = I.shape[0], I.shape[1], 1, 1
+        g_out_sc = n * g_out[0]
+
+        if len(I.shape) == 4:
+            h, w = I.shape[2], I.shape[3]
+            I = I.reshape(n, c, -1)
+            g_out_sc = g_out_sc.reshape(n, c, -1)
+
+        G = g_out_sc
+
+        grad = module.bias.grad.reshape(-1)
+
+        J = g_out_sc
+
+        # if len(I.shape) == 2:
+        #     J = g_out_sc
+        # else:
+        #     J = torch.einsum('ncf->nf', g_out_sc)
+
+        J = J.reshape(J.shape[0], -1)
+        JJT = torch.matmul(J, J.t())
+
+        grad_prod = torch.matmul(J, grad)
+
+        NGD_kernel = JJT / n
+        NGD_inv = torch.linalg.inv(NGD_kernel + self.damping * torch.eye(n).to(grad.device))
+        v = torch.matmul(NGD_inv, grad_prod)
+
+        gv = torch.matmul(J.t(), v) / n
+
+        update = (grad - gv) / self.damping
+        update = update.reshape(module.bias.grad.shape)
+
+        module.NGD_inv = NGD_inv
+        module.G = G
+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel
+
+        return update
\ No newline at end of file
diff --git a/backpack/extensions/firstorder/fisher_block_eff/linear.py b/backpack/extensions/firstorder/fisher_block_eff/linear.py
index c5164a451..92ab4ef42 100644
--- a/backpack/extensions/firstorder/fisher_block_eff/linear.py
+++ b/backpack/extensions/firstorder/fisher_block_eff/linear.py
@@ -6,9 +6,10 @@
 
 
 class FisherBlockEffLinear(FisherBlockEffBase):
-    def __init__(self, damping=1.0, alpha=0.95):
+    def __init__(self, damping=1.0, alpha=0.95, save_kernel='false'):
         self.damping = damping
         self.alpha = alpha
+        self.save_kernel = save_kernel
         super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"])
 
     def weight(self, ext, module, g_inp, g_out, backproped):
@@ -42,6 +43,8 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         module.I = I
         module.G = G
         module.NGD_inv = NGD_inv
+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel
 
         return update
 
@@ -68,6 +71,9 @@ def bias(self, ext, module, g_inp, g_out, backproped):
 
         update = (grad - gv)/self.damping
         # update = grad
 
+        if self.save_kernel == 'true':
+            module.NGD_kernel = NGD_kernel
+
         return update
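
Usage note (not part of the diff): a minimal sketch of how the new LayerNorm handler and the save_kernel flag could be exercised. It assumes this fork keeps stock BackPACK's extend/backpack API and that FisherBlockEff is importable from backpack.extensions; the toy model, tensor shapes, and the print loops are illustrative only.

import torch
from torch import nn

from backpack import backpack, extend
# Assumption: FisherBlockEff is re-exported at this path in the fork; adjust the
# import if it is only exposed under backpack.extensions.firstorder.fisher_block_eff.
from backpack.extensions import FisherBlockEff

# Toy model containing a LayerNorm, which the extension now also hooks.
model = extend(nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16), nn.Linear(16, 4)))
loss_fn = extend(nn.CrossEntropyLoss())

X, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
loss = loss_fn(model(X), y)

# save_kernel='true' additionally stores the n x n NGD kernel (J J^T / n) on each hooked module.
with backpack(FisherBlockEff(damping=1.0, save_kernel='true')):
    loss.backward()

for m in model.modules():
    if hasattr(m, "NGD_kernel"):
        print(type(m).__name__, tuple(m.NGD_kernel.shape))  # expected: (8, 8), n = batch size

for name, p in model.named_parameters():
    print(name, p.fisher_block.shape)  # preconditioned update stored in the savefield

The LayerNorm update follows the same matrix-inversion-lemma form as the existing Linear/Conv2d handlers: with per-sample Jacobian rows J, the handler solves the n x n system v = (J J^T / n + damping * I)^-1 (J grad) and returns (grad - J^T v / n) / damping, so only a batch-sized matrix is inverted per layer.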