diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py index 71896012e183..2b638042b118 100644 --- a/src/transformers/models/vjepa2/modeling_vjepa2.py +++ b/src/transformers/models/vjepa2/modeling_vjepa2.py @@ -189,8 +189,8 @@ def rotate_queries_or_keys(x, pos): emb_sin = freq.sin() # (..., N, D/2) emb_cos = freq.cos() # (..., N, D/2) - emb_sin = emb_sin.squeeze(-1).repeat(1, 1, 1, 2) - emb_cos = emb_cos.squeeze(-1).repeat(1, 1, 1, 2) + emb_sin = emb_sin.repeat(1, 1, 1, 2) + emb_cos = emb_cos.repeat(1, 1, 1, 2) # -- y = x.unflatten(-1, (-1, 2))