Hi @jbcdnr, @martinjaggi
Thanks for this work, it's quite intuitive and easy to understand :)
When I try to run the provided script, I get the following error.
Code provided:
from transformers import AutoModel
from collaborative_attention import swap_to_collaborative, BERTCollaborativeAdapter
import copy
import torch
model = AutoModel.from_pretrained("bert-base-cased-finetuned-mrpc")
# reparametrize the model with tensor decomposition to use collaborative heads
# decrease dim_shared_query_key to 384 for example to compress the model
collab_model = copy.deepcopy(model)
swap_to_collaborative(collab_model, BERTCollaborativeAdapter, dim_shared_query_key=768)
# check that output is not altered too much
any_input = torch.LongTensor(3, 25).random_(1000, 10000)
collab_model.eval() # to disable dropout
out_collab = collab_model(any_input)
model.eval()
out_original = model(any_input)
print("Max l1 error: {:.1e}".format((out_collab[0] - out_original[0]).abs().max().item()))
# >>> Max l1 error: 1.9e-06
# You can evaluate the new model, fine-tune it, or save it.
# We also want to pretrain our collaborative head from scratch (if you were wondering).
Error message:
out_collab = collab_model(any_input)
TypeError: CollaborativeAttention.forward() takes from 2 to 6 positional arguments but 8 were given
My bet would be a compatibility issue between library versions. Could you please help with this?
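For reference, this is how I am checking my environment (a minimal sketch; I am assuming the README example was written against an older transformers release, but I do not know which version it targets):
import torch
import transformers
# Print installed versions to compare against the ones pinned in the
# repo's requirements file; a mismatch there would explain the error.
print("transformers:", transformers.__version__)
print("torch:", torch.__version__)
If the installed transformers version is newer than the one pinned in the repo, downgrading to the pinned version may resolve the signature mismatch.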
Happy to share additional details if needed.
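In the meantime, here is a rough workaround sketch I am considering. This is purely my guess at the cause: newer transformers versions pass extra positional arguments (e.g. past_key_value, output_attentions) into the attention module, which the older forward() signature does not accept. ForwardCompatAttention is a hypothetical wrapper name, not from this repo:
import torch.nn as nn

class ForwardCompatAttention(nn.Module):
    # Hypothetical shim: pass through only the arguments the wrapped
    # module's older forward() signature accepts and drop the rest.
    def __init__(self, collab_attention):
        super().__init__()
        self.inner = collab_attention

    def forward(self, hidden_states, attention_mask=None, head_mask=None,
                *unused_args, **unused_kwargs):
        # Extra positional arguments from a newer transformers call site
        # (e.g. encoder_hidden_states, past_key_value) are ignored here.
        return self.inner(hidden_states, attention_mask, head_mask)
I have not verified that the returned tuple still matches what a newer BertLayer expects, so this is only a starting point for debugging.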