Conversation

@vivianwhite (Collaborator)

Adding Covariance module and low-rank covariance experiments

@MuawizChaudhary left a comment

Just a high-level overview; I'll do additional passes.

help='filename for saved model')
parser.add_argument('--seed', default=0, type=int, help='seed to use')
parser.add_argument('--width', type=float, default=1, help='resnet width scale factor')
parser.add_argument('--spatial_k', type=float, default=1, help='%spatial low-rank')

consider removing

Comment on lines 88 to 91
print("Num Learnable Params: ", sum(p.numel() for p in net.parameters() if
p.requires_grad))

run = wandb.init(project="factconv", config=args, group="lowrankplusdiag", name=run_name, dir=wandb_dir)

might be worth logging the number of parameters
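
For example, a minimal sketch of one way to do this, reusing the net and run objects from the snippet above (the config key name is just illustrative):

num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
# record it in the run config so it shows up alongside the other hyperparameters
run.config.update({"num_learnable_params": num_params})
# and/or expose it as a summary value for the run
run.summary["num_learnable_params"] = num_params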

Comment on lines +98 to 100
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,\
T_max=args.num_epochs)


good job

@@ -1,201 +0,0 @@
import torch
from torch import Tensor

Why is this script removed here?

@vivianwhite (Author):

Good question. It appears to be a mistake I made during rebasing.



class FactConv2d(nn.Conv2d):
def __init__(

This entire script re-uses a lot of code. I wonder if you could make the FactConv extensions subclasses of this FactConv2d module.
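
One possible shape for that refactor, sketched here with a hypothetical _build_covariance hook (the name is not from the PR): keep the shared convolution setup in FactConv2d and have each variant override only the covariance construction.

import torch.nn as nn

class FactConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # shared setup (weight factorization, buffers, etc.) lives in the base class
        self.channel_cov = self._build_covariance()

    def _build_covariance(self):
        # full covariance parameterization
        ...

class DiagFactConv2d(FactConv2d):
    def _build_covariance(self):
        # diagonal-only variant overrides just this hook
        ...

class LowRankPlusDiagFactConv2d(FactConv2d):
    def _build_covariance(self):
        # low-rank plus diagonal variant, same idea
        ...

With a hierarchy like this, the long isinstance chains further down would also collapse to a single isinstance(module, FactConv2d) check.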


As you mentioned, eventually we should refactor the other FactConvs to use the Covariance module.

from typing import Optional, List, Tuple, Union
import math

class Covariance(nn.Module):

There's some repeated code here too. I wonder if we could condense it down.

Comment on lines 200 to 202
if isinstance(module, FactConv2d) or isinstance(module, DiagFactConv2d)\
or isinstance(module, DiagChanFactConv2d) or isinstance(module,LowRankFactConv2d)\
or isinstance(module, LowRankPlusDiagFactConv2d):

If we had all these FactConv variations subclassing FactConv2d, you could just check whether the module is a FactConv2d instance.

or isinstance(module, DiagChanFactConv2d) or isinstance(module,LowRankFactConv2d)\
or isinstance(module, LowRankPlusDiagFactConv2d):
for name, mod in module.named_modules():
if isinstance(mod, Covariance):

Actually, I think it's better to remove the FactConv check entirely and just check whether the module is a Covariance, and disable gradients that way.
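
A minimal sketch of that suggestion, assuming net is the model and that disabling gradients means setting requires_grad to False on the Covariance parameters:

for name, mod in net.named_modules():
    if isinstance(mod, Covariance):
        # freeze every Covariance module directly; no FactConv type check needed
        for p in mod.parameters():
            p.requires_grad = False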

@kamdh (Contributor) commented Feb 14, 2025

Are we ready to merge?
