Adding low-rank covariance #4
base: dev
Conversation
…h 8, removed evaluation on affine models, etc.
…urns the modified network (though the network is modified in place, so this is redundant)
has been moved to FactConv directory
MuawizChaudhary left a comment
Just a high-level overview; I'll do additional passes.
                    help='filename for saved model')
parser.add_argument('--seed', default=0, type=int, help='seed to use')
parser.add_argument('--width', type=float, default=1, help='resnet width scale factor')
parser.add_argument('--spatial_k', type=float, default=1, help='%spatial low-rank')
Consider removing this.
FactConv/pytorch_cifar.py
Outdated
print("Num Learnable Params: ", sum(p.numel() for p in net.parameters() if
    p.requires_grad))

run = wandb.init(project="factconv", config=args, group="lowrankplusdiag", name=run_name, dir=wandb_dir)
Might be worth logging the number of parameters.
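One way to do this is to compute the parameter count once and attach it to the run's config. A minimal sketch, using a toy network in place of the ResNet from the script (the `num_learnable_params` config key and the commented `wandb` call are illustrative, not from the PR):

```python
import torch
from torch import nn

# Toy network standing in for the ResNet built in pytorch_cifar.py.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 10))

# Same counting expression the script already prints.
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("Num Learnable Params:", num_params)

# Hypothetical logging step (the wandb run is created elsewhere in the script):
# run = wandb.init(project="factconv", config=args, ...)
# run.config.update({"num_learnable_params": num_params})
```

Putting the count in `config` rather than `wandb.log` keeps it queryable as a run-level field in the dashboard.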
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
    T_max=args.num_epochs)
Good job.
@@ -1,201 +0,0 @@
import torch
from torch import Tensor
Why is this script removed here?
Good question. It appears to be a mistake I made during rebasing.
class FactConv2d(nn.Conv2d):
    def __init__(
This entire script reuses a lot of code. I wonder if you could subclass the FactConv extensions from this FactConv module.
As you mentioned, eventually we should refactor the other FactConvs to use the Covariance module.
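The proposed refactor can be sketched with plain classes (the real variants subclass `nn.Conv2d`; bodies are elided here, and the hierarchy is an assumption about how the PR's classes would be arranged):

```python
class FactConv2d:
    """Stand-in for the base factorized convolution (really an nn.Conv2d subclass)."""

class DiagFactConv2d(FactConv2d):
    """Diagonal-covariance variant."""

class DiagChanFactConv2d(FactConv2d):
    """Diagonal channel-covariance variant."""

class LowRankFactConv2d(FactConv2d):
    """Low-rank covariance variant."""

class LowRankPlusDiagFactConv2d(FactConv2d):
    """Low-rank plus diagonal covariance variant, as added in this PR."""

module = LowRankPlusDiagFactConv2d()
# A single base-class check replaces the chained isinstance tests
# in function_utils.py:
print(isinstance(module, FactConv2d))  # True
```

Each subclass would then override only the covariance parameterization, so the shared forward logic lives in one place.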
from typing import Optional, List, Tuple, Union
import math

class Covariance(nn.Module):
There's some repeated code here too. I wonder if we could condense this down.
FactConv/models/function_utils.py
Outdated
if isinstance(module, FactConv2d) or isinstance(module, DiagFactConv2d)\
        or isinstance(module, DiagChanFactConv2d) or isinstance(module, LowRankFactConv2d)\
        or isinstance(module, LowRankPlusDiagFactConv2d):
If all these FactConv variants subclassed FactConv, you could just check whether the module is a FactConv instance.
FactConv/models/function_utils.py
Outdated
        or isinstance(module, DiagChanFactConv2d) or isinstance(module, LowRankFactConv2d)\
        or isinstance(module, LowRankPlusDiagFactConv2d):
    for name, mod in module.named_modules():
        if isinstance(mod, Covariance):
Actually, I think it's better to remove the FactConv check entirely and just check whether the module is a Covariance, disabling gradients that way.
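That simplification can be sketched as follows. This is a minimal, hypothetical layout: the `tri_vec` parameter and `channel_cov` attribute names are illustrative, not taken from the PR's actual Covariance module.

```python
import torch
from torch import nn

class Covariance(nn.Module):
    """Minimal stand-in for the Covariance module added in this PR."""
    def __init__(self, dim):
        super().__init__()
        # Hypothetical parameterization: a flattened triangular factor.
        self.tri_vec = nn.Parameter(torch.zeros(dim * (dim + 1) // 2))

class FactConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__(in_channels, out_channels, kernel_size)
        self.channel_cov = Covariance(in_channels)

net = nn.Sequential(FactConv2d(3, 8, 3), nn.ReLU())

# Freeze every Covariance submodule directly -- no chained
# isinstance checks over the FactConv variants needed.
for mod in net.modules():
    if isinstance(mod, Covariance):
        for p in mod.parameters():
            p.requires_grad = False

frozen = [n for n, p in net.named_parameters() if not p.requires_grad]
print(frozen)  # only the covariance parameters are frozen
```

Because `Module.modules()` already walks the full tree, this picks up covariance parameters regardless of which FactConv variant owns them.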
Are we ready to merge?
Adding Covariance module and low-rank covariance experiments