
The provided script might not be compatible with the latest Hugging Face versions #7

@NamburiSrinath

Hi @jbcdnr, @martinjaggi

Thanks for this work, it's quite intuitive and easy to understand :)

When I try to run the provided script, I get the following error.

Code provided:

from transformers import AutoModel
from collaborative_attention import swap_to_collaborative, BERTCollaborativeAdapter
import copy
import torch

model = AutoModel.from_pretrained("bert-base-cased-finetuned-mrpc")

# reparametrize the model with tensor decomposition to use collaborative heads
# decrease dim_shared_query_key to 384 for example to compress the model
collab_model = copy.deepcopy(model)
swap_to_collaborative(collab_model, BERTCollaborativeAdapter, dim_shared_query_key=768)

# check that output is not altered too much
any_input = torch.LongTensor(3, 25).random_(1000, 10000)
collab_model.eval()  # to disable dropout
out_collab = collab_model(any_input)

model.eval()
out_original = model(any_input)

print("Max l1 error: {:.1e}".format((out_collab[0] - out_original[0]).abs().max().item()))
# >>> Max l1 error: 1.9e-06

# You can evaluate the new model, fine-tune it, or save it.
# We also want to pretrain our collaborative head from scratch (if you were wondering).

Error message:

out_collab = collab_model(any_input)
TypeError: CollaborativeAttention.forward() takes from 2 to 6 positional arguments but 8 were given
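For what it's worth, on the transformers release I have installed (a 4.x version), BertSelfAttention.forward declares seven parameters besides self, and BertAttention appears to pass all of them positionally, which would explain the "8 were given" above. A quick way to check (assuming the 4.x module path):

import inspect
from transformers.models.bert.modeling_bert import BertSelfAttention

# On my install this prints eight names: self, hidden_states, attention_mask,
# head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value,
# output_attentions -- i.e. 8 positional arguments at call time, while
# CollaborativeAttention.forward only accepts 2 to 6.
print(list(inspect.signature(BertSelfAttention.forward).parameters))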

My bet would be a version incompatibility between this repo and recent transformers releases. Can you please help with this?
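In case it helps others hitting this, here is a stopgap I sketched. It is untested beyond a smoke run and assumes the extra positional arguments are the newer head_mask / cross-attention / cache ones that a plain BERT encoder forward does not need, so trimming them should be harmless:

import inspect

def tolerate_extra_args(module):
    # Wrap module.forward so that calls with more positional or keyword
    # arguments than it declares (as newer transformers releases make)
    # are trimmed instead of raising a TypeError.
    original_forward = module.forward  # bound method, so `self` is excluded
    params = inspect.signature(original_forward).parameters
    n_positional = len(params)
    allowed_keywords = set(params)

    def forward(*args, **kwargs):
        kwargs = {k: v for k, v in kwargs.items() if k in allowed_keywords}
        return original_forward(*args[:n_positional], **kwargs)

    module.forward = forward

# Apply to every swapped attention module after swap_to_collaborative;
# matching on the class name avoids importing CollaborativeAttention here.
for module in collab_model.modules():
    if type(module).__name__ == "CollaborativeAttention":
        tolerate_extra_args(module)

With that in place the forward call should no longer trip over the argument count, though pinning transformers to the release this repo was developed against is probably the safer route.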

Happy to share additional details if needed.
