Description
Not sure if this is intended behavior or a bug, but cuequivariance_torch.attention_pair_bias returns a tuple whose second element is None even when return_z_proj=False, rather than a single tensor. This doesn't match the example in the docs.
From the docs: https://docs.nvidia.com/cuda/cuequivariance/api/generated/cuequivariance_torch.attention_pair_bias.html
import torch
import cuequivariance
from cuequivariance_torch import attention_pair_bias
# Print torch and cuda version
print("Torch version:", torch.__version__)
print("cuEquivariance version:", cuequivariance.__version__)
if torch.cuda.is_available():
    device = torch.device("cuda")
    batch_size, seq_len, num_heads, heads_dim, hidden_dim = 1, 32, 2, 32, 64
    query_len, key_len, z_dim = 32, 32, 16
    # Create input tensors on GPU
    s = torch.randn(batch_size, seq_len, hidden_dim,
                    device=device, dtype=torch.bfloat16)
    q = torch.randn(batch_size, num_heads, query_len, heads_dim,
                    device=device, dtype=torch.bfloat16)
    k = torch.randn(batch_size, num_heads, key_len, heads_dim,
                    device=device, dtype=torch.bfloat16)
    v = torch.randn(batch_size, num_heads, key_len, heads_dim,
                    device=device, dtype=torch.bfloat16)
    z = torch.randn(batch_size, query_len, key_len, z_dim,
                    device=device, dtype=torch.bfloat16)
    mask = torch.rand(batch_size, key_len,
                      device=device) < 0.5
    w_proj_z = torch.randn(num_heads, z_dim,
                           device=device, dtype=torch.bfloat16)
    w_proj_g = torch.randn(hidden_dim, hidden_dim,
                           device=device, dtype=torch.bfloat16)
    w_proj_o = torch.randn(hidden_dim, hidden_dim,
                           device=device, dtype=torch.bfloat16)
    w_ln_z = torch.randn(z_dim,
                         device=device, dtype=torch.bfloat16)
    b_ln_z = torch.randn(z_dim,
                         device=device, dtype=torch.bfloat16)
    # Perform operation
    output = attention_pair_bias(
        s=s,
        q=q,
        k=k,
        v=v,
        z=z,
        mask=mask,
        num_heads=num_heads,
        w_proj_z=w_proj_z,
        w_proj_g=w_proj_g,
        w_proj_o=w_proj_o,
        w_ln_z=w_ln_z,
        b_ln_z=b_ln_z,
        return_z_proj=False,
    )
    print(output[0].shape, output[1])  # Not from docs
    print(output.shape)  # torch.Size([1, 32, 64])
Outputs:
Torch version: 2.8.0+cu129
cuEquivariance version: 0.8.0
torch.Size([1, 32, 64]) None
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[5], line 53
     37 output = attention_pair_bias(
     38 s=s,
     39 q=q,
   (...)
     50 return_z_proj=False,
     51 )
     52 print(output[0].shape, output[1])
---> 53 print(output.shape) # torch.Size([1, 32, 64])
AttributeError: 'tuple' object has no attribute 'shape'
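Until it is clear which behavior is intended, a caller-side workaround is to normalize the return value before using it. The following is only a minimal sketch: `unwrap_output` is a hypothetical helper, not part of cuequivariance_torch, and it assumes that when a tuple is returned its first element is the attention output (as the printed shape above suggests).

```python
def unwrap_output(result):
    """Return the attention output whether `result` is a bare value or a tuple.

    Assumption: when a tuple is returned, element 0 is the attention output
    and any remaining elements (e.g. the z projection) may be None.
    """
    if isinstance(result, tuple):
        return result[0]
    return result
```

With this, `unwrap_output(attention_pair_bias(..., return_z_proj=False)).shape` should work under both the documented and the observed return conventions.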