-
Notifications
You must be signed in to change notification settings - Fork 24
Description
Describe the bug
I am trying to use cuEquivariance with openfold2 (which is on pytorch 2.5 with CUDA 12.4). When running the tests after installation, triangle_multiplicative_update failed to import, with the following traceback:
Traceback (most recent call last):
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/tests/test_cuequivariance.py", line 154, in test_compare_model
out_repro_mul = model(batch)
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/model.py", line 581, in forward
outputs, m_1_prev, z_prev, x_prev, early_stop = self.iteration(
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/model.py", line 330, in iteration
template_embeds = self.embed_templates(
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/model.py", line 170, in embed_templates
template_embeds = self.template_embedder(
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/embedders.py", line 710, in forward
t = self.template_pair_stack(
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/template.py", line 479, in forward
t, = checkpoint_blocks(
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/utils/checkpointing.py", line 85, in checkpoint_blocks
return exec(blocks, args)
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/utils/checkpointing.py", line 72, in exec
a = wrap(block(*a))
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/template.py", line 328, in forward
single = self.tri_mul_out_in(
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/template.py", line 260, in tri_mul_out_in
tmu_update = self.tri_mul_out(
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/triangular_multiplicative_update.py", line 494, in forward
result = _cuequivariance_triangular_mult(
File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/triangular_multiplicative_update.py", line 82, in _cuequivariance_triangular_mult
return triangle_multiplicative_update(
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/cuequivariance_torch/primitives/triangle.py", line 242, in triangle_multiplicative_update
return f(
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/cuequivariance_ops_torch/__init__.py", line 75, in triangle_multiplicative_update
raise Exception(
Exception: Failed to import Triton-based component: triangle_multiplicative_update:
Traceback (most recent call last):
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/cuequivariance_ops_torch/__init__.py", line 63, in <module>
from cuequivariance_ops_torch.triangle_multiplicative_update import (
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/cuequivariance_ops_torch/triangle_multiplicative_update.py", line 187, in <module>
def _(
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 133, in inner
schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/_library/infer_schema.py", line 106, in infer_schema
error_fn(
File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/_library/infer_schema.py", line 58, in error_fn
raise ValueError(
ValueError: infer_schema(func): Parameter valid_optional_inputs has unsupported type list[bool]. The valid types are: dict_keys([<class 'torch.Tensor'>, typing.Optional[torch.Tensor], typing.Sequence[torch.Tensor], typing.List[torch.Tensor], typing.Sequence[typing.Optional[torch.Tensor]], typing.List[typing.Optional[torch.Tensor]], <class 'int'>, typing.Optional[int], typing.Sequence[int], typing.List[int], typing.Optional[typing.Sequence[int]], typing.Optional[typing.List[int]], <class 'float'>, typing.Optional[float], typing.Sequence[float], typing.List[float], typing.Optional[typing.Sequence[float]], typing.Optional[typing.List[float]], <class 'bool'>, typing.Optional[bool], typing.Sequence[bool], typing.List[bool], typing.Optional[typing.Sequence[bool]], typing.Optional[typing.List[bool]], <class 'str'>, typing.Optional[str], typing.Union[int, float, bool], typing.Union[int, float, bool, NoneType], typing.Sequence[typing.Union[int, float, bool]], typing.List[typing.Union[int, float, bool]], <class 'torch.dtype'>, typing.Optional[torch.dtype], <class 'torch.device'>, typing.Optional[torch.device]]). Got func with signature (x: torch.Tensor, mask: torch.Tensor, norm_in_weight: torch.Tensor, norm_in_bias: torch.Tensor, p_in_weight: torch.Tensor, p_in_bias: torch.Tensor, g_in_weight: torch.Tensor, g_in_bias: torch.Tensor, norm_out_weight: torch.Tensor, norm_out_bias: torch.Tensor, p_out_weight: torch.Tensor, p_out_bias: torch.Tensor, g_out_weight: torch.Tensor, g_out_bias: torch.Tensor, direction: str, eps: float, precision: int, valid_optional_inputs: list[bool]) -> torch.Tensor)
Please make sure to install triton==3.3.0. Other versions may not work!
Note that I have tried both triton 3.1.0 (which comes with pytorch 2.5) and triton 3.3.0, and it makes no difference.
To Reproduce
First, follow the instructions on the openfold2 page to set up the environment, then additionally install cuequivariance with $ pip install cuequivariance_ops_torch_cu12 cuequivariance_torch.
At this point, attempting to run the tests with cuequivariance actually produces a different error: it fails to import libcue_ops.so. This can be fixed by reinstalling pytorch with pip:
pip install torch==2.5.1
pip install nvidia-cublas-cu12==12.9.1.4 # need to reinstall cublas because pytorch 2.5.1 would downgrade it
Then run openfold2 tests:
scripts/run_unit_tests.sh
Expected behavior
Test scripts should be able to run with triangle_multiplicative_update.
Solution
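This problem appears to be caused by how the type annotations are written in the cuequivariance_ops_torch package for triangle_multiplicative_update: as the traceback shows, torch.library.infer_schema in pytorch 2.5 rejects the built-in generic list[bool], while typing.List[bool] is among the accepted types. It can be resolved if I patch the code and change all list annotations to typing.List in triangle_multiplicative_update.py and attention_pair_bias_torch.py.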
I am not sure about the origin of this issue; it could be related to the pytorch version. Would it be possible for you to adjust how the type annotations are written in the two files mentioned above, so that the package is compatible with openfold2 (I believe this change shouldn't cause any issues otherwise)?
GPU HW/SW (please complete the following information):
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
nvidia-smi
Tue Jan 6 18:06:23 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05 Driver Version: 550.127.05 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A100-PCIE-40GB On | 00000000:D8:00.0 Off | 0 |
| N/A 32C P0 54W / 250W | 1MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
pip list
Package Version
----------------------------- ----------
absl-py 2.3.1
aiohappyeyeballs 2.6.1
aiohttp 3.13.3
aiosignal 1.4.0
annotated-types 0.7.0
appdirs 1.4.4
async-timeout 5.0.1
attrs 25.4.0
awscli 2.32.29
awscrt 0.29.1
biopython 1.86
Brotli 1.2.0
certifi 2026.1.4
charset-normalizer 3.4.4
click 8.3.1
colorama 0.4.6
cuequivariance 0.8.0
cuequivariance-ops-cu12 0.8.0
cuequivariance-ops-torch-cu12 0.8.0
cuequivariance-torch 0.8.0
deepspeed 0.14.5
distro 1.8.0
DLLogger 1.1.0
dm-tree 0.1.6
docker-pycreds 0.4.0
docutils 0.19
eval_type_backport 0.3.1
filelock 3.20.2
frozenlist 1.7.0
fsspec 2025.12.0
gitdb 4.0.12
GitPython 3.1.46
gmpy2 2.2.1
hjson 3.1.0
idna 3.11
ihm 2.8
Jinja2 3.1.6
jmespath 1.0.1
lightning-utilities 0.15.2
MarkupSafe 3.0.3
ml_collections 1.0.0
modelcif 0.7
mpmath 1.3.0
msgpack 1.1.2
multidict 6.7.0
networkx 3.4.2
ninja 1.13.0
numpy 2.2.6
nvidia-cublas-cu12 12.9.1.4
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-ml-py 13.590.44
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
openfold 2.2.0
OpenMM 8.2.0
opt_einsum 3.4.0
packaging 25.0
pandas 2.3.3
pdbfixer 1.12.0
pip 25.3
platformdirs 4.5.1
prompt_toolkit 3.0.51
propcache 0.3.1
protobuf 6.32.1
psutil 7.2.1
py-cpuinfo 9.0.0
pydantic 2.12.5
pydantic_core 2.41.5
PySocks 1.7.1
python-dateutil 2.9.0
pytorch-lightning 2.6.0
pytz 2025.2
PyYAML 6.0.3
requests 2.32.5
ruamel.yaml 0.17.21
ruamel.yaml.clib 0.2.12
scipy 1.15.2
sentry-sdk 2.48.0
setproctitle 1.3.7
setuptools 59.5.0
six 1.17.0
smmap 5.0.2
sympy 1.13.1
torch 2.5.1
torchmetrics 1.8.2
tqdm 4.67.1
triton 3.1.0
typing_extensions 4.15.0
typing-inspection 0.4.2
tzdata 2025.3
urllib3 1.26.20
wandb 0.23.1
wcwidth 0.2.14
wheel 0.45.1
yarl 1.22.0