
Typing issue for triangle_multiplicative_update when using cuEquivariance with openfold2 #229

@dxu16


Describe the bug
I am trying to use cuEquivariance with openfold2 (which uses PyTorch 2.5 with CUDA 12.4). When I run the tests after installation, triangle_multiplicative_update fails to import, with the following traceback:

Traceback (most recent call last):
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/tests/test_cuequivariance.py", line 154, in test_compare_model
    out_repro_mul = model(batch)
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/model.py", line 581, in forward
    outputs, m_1_prev, z_prev, x_prev, early_stop = self.iteration(
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/model.py", line 330, in iteration
    template_embeds = self.embed_templates(
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/model.py", line 170, in embed_templates
    template_embeds = self.template_embedder(
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/embedders.py", line 710, in forward
    t = self.template_pair_stack(
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/template.py", line 479, in forward
    t, = checkpoint_blocks(
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/utils/checkpointing.py", line 85, in checkpoint_blocks
    return exec(blocks, args)
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/utils/checkpointing.py", line 72, in exec
    a = wrap(block(*a))
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/template.py", line 328, in forward
    single = self.tri_mul_out_in(
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/template.py", line 260, in tri_mul_out_in
    tmu_update = self.tri_mul_out(
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/triangular_multiplicative_update.py", line 494, in forward
    result = _cuequivariance_triangular_mult(
  File "/scratch4/jgray21/dxu39/projects/AF2Dock/openfold/openfold/model/triangular_multiplicative_update.py", line 82, in _cuequivariance_triangular_mult
    return triangle_multiplicative_update(
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/cuequivariance_torch/primitives/triangle.py", line 242, in triangle_multiplicative_update
    return f(
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/cuequivariance_ops_torch/__init__.py", line 75, in triangle_multiplicative_update
    raise Exception(
Exception: Failed to import Triton-based component: triangle_multiplicative_update:
Traceback (most recent call last):
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/cuequivariance_ops_torch/__init__.py", line 63, in <module>
    from cuequivariance_ops_torch.triangle_multiplicative_update import (
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/cuequivariance_ops_torch/triangle_multiplicative_update.py", line 187, in <module>
    def _(
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 133, in inner
    schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/_library/infer_schema.py", line 106, in infer_schema
    error_fn(
  File "/scratch4/jgray21/dxu39/miniforge3/envs/openfold_test/lib/python3.10/site-packages/torch/_library/infer_schema.py", line 58, in error_fn
    raise ValueError(
ValueError: infer_schema(func): Parameter valid_optional_inputs has unsupported type list[bool]. The valid types are: dict_keys([<class 'torch.Tensor'>, typing.Optional[torch.Tensor], typing.Sequence[torch.Tensor], typing.List[torch.Tensor], typing.Sequence[typing.Optional[torch.Tensor]], typing.List[typing.Optional[torch.Tensor]], <class 'int'>, typing.Optional[int], typing.Sequence[int], typing.List[int], typing.Optional[typing.Sequence[int]], typing.Optional[typing.List[int]], <class 'float'>, typing.Optional[float], typing.Sequence[float], typing.List[float], typing.Optional[typing.Sequence[float]], typing.Optional[typing.List[float]], <class 'bool'>, typing.Optional[bool], typing.Sequence[bool], typing.List[bool], typing.Optional[typing.Sequence[bool]], typing.Optional[typing.List[bool]], <class 'str'>, typing.Optional[str], typing.Union[int, float, bool], typing.Union[int, float, bool, NoneType], typing.Sequence[typing.Union[int, float, bool]], typing.List[typing.Union[int, float, bool]], <class 'torch.dtype'>, typing.Optional[torch.dtype], <class 'torch.device'>, typing.Optional[torch.device]]). Got func with signature (x: torch.Tensor, mask: torch.Tensor, norm_in_weight: torch.Tensor, norm_in_bias: torch.Tensor, p_in_weight: torch.Tensor, p_in_bias: torch.Tensor, g_in_weight: torch.Tensor, g_in_bias: torch.Tensor, norm_out_weight: torch.Tensor, norm_out_bias: torch.Tensor, p_out_weight: torch.Tensor, p_out_bias: torch.Tensor, g_out_weight: torch.Tensor, g_out_bias: torch.Tensor, direction: str, eps: float, precision: int, valid_optional_inputs: list[bool]) -> torch.Tensor)

Please make sure to install triton==3.3.0. Other versions may not work!

Note that I have tried both Triton 3.1.0 (the version that comes with PyTorch 2.5) and Triton 3.3.0, and the error is the same with both.
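The error is raised by torch.library.infer_schema while the custom op is being registered, before any Triton kernel runs, which would explain why the Triton version makes no difference. The supported-type table printed in the message contains typing.List[bool] but not the builtin list[bool], and in Python these two spellings of the same type are distinct objects that do not compare equal. A minimal illustration in plain Python, using a small hypothetical stand-in set rather than PyTorch's actual table:

```python
import typing

# Hypothetical stand-in for a few entries of the supported-type table shown
# in the error message (the real table lives inside PyTorch's infer_schema).
SUPPORTED_ANNOTATIONS = {
    bool,
    typing.List[bool],
    typing.Optional[typing.List[bool]],
}


def is_supported(annotation) -> bool:
    """Membership test by hash/equality, like a dict or set lookup."""
    return annotation in SUPPORTED_ANNOTATIONS


# typing.List[bool] and the PEP 585 builtin generic list[bool] describe the
# same type (same origin and args), but they are different alias objects that
# do not compare equal, so a lookup keyed on one does not find the other.
print(typing.List[bool] == list[bool])        # False
print(typing.get_origin(list[bool]) is list)  # True
print(is_supported(typing.List[bool]))        # True
print(is_supported(list[bool]))               # False
```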

To Reproduce
First, follow the instructions on the openfold2 page to set up the environment, then additionally install cuEquivariance with $ pip install cuequivariance_ops_torch_cu12 cuequivariance_torch.

At this point, running the tests with cuEquivariance actually fails with a different error: libcue_ops.so cannot be imported. This can be fixed by reinstalling PyTorch with pip:

pip install torch==2.5.1
pip install nvidia-cublas-cu12==12.9.1.4  # reinstall cuBLAS, because installing torch 2.5.1 downgrades it
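To confirm that the reinstall left the intended versions in place (torch 2.5.1 and the newer cuBLAS), a small check using only the standard library's importlib.metadata could look like the sketch below. The distribution names are taken from the pip list further down; adjust them if your environment differs.

```python
from importlib import metadata

# Distributions relevant to this report (names as they appear in `pip list`).
PACKAGES = [
    "torch",
    "triton",
    "nvidia-cublas-cu12",
    "cuequivariance-torch",
    "cuequivariance-ops-torch-cu12",
]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name:32} {version or 'not installed'}")
```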

Then run openfold2 tests:

scripts/run_unit_tests.sh
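A quicker way to reproduce the failure, without running the whole test suite, may be to import the failing submodule directly. The module name comes from the traceback above; the assumption (not verified against the package source) is that the package's __init__ only records the import failure, so a fresh import of the submodule should raise the same ValueError from infer_schema.

```python
import importlib
import traceback

# Submodule named in the traceback above; importing it re-runs the
# torch.library custom-op registration that fails.
MODULE = "cuequivariance_ops_torch.triangle_multiplicative_update"


def try_import(module_name):
    """Return None if the import succeeds, else the formatted traceback."""
    try:
        importlib.import_module(module_name)
    except Exception:  # e.g. ModuleNotFoundError, or ValueError from infer_schema
        return traceback.format_exc()
    return None


error = try_import(MODULE)
print("import OK" if error is None else error)
```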

Expected behavior
The test scripts should run successfully with triangle_multiplicative_update.

Solution
This problem appears to come from how the type annotations for triangle_multiplicative_update are written in the cuequivariance_ops_torch package. It can be resolved by patching the code to change every builtin list[...] annotation to typing.List[...] in triangle_multiplicative_update.py and attention_pair_bias_torch.py.
I am not sure about the origin of this issue; it could depend on the PyTorch version. Would it be possible for you to adjust how the types are written in these two files so that the package is compatible with openfold2? I believe the change should not cause any issues otherwise.
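For reference, the local workaround described above could be scripted roughly as follows. This is a minimal, purely textual sketch: it assumes the offending annotations are spelled list[...] and that rewriting every such occurrence (including any inside strings or comments) is acceptable. The helper names are hypothetical and this is not an official fix.

```python
import re
from pathlib import Path

# A PEP 585 builtin generic such as ``list[bool]``; the lookbehind skips
# names like ``mylist[`` and attribute access like ``obj.list[``.
_BUILTIN_LIST = re.compile(r"(?<![\w.])list\[")
_HAS_IMPORT_TYPING = re.compile(r"^import typing\s*$", re.MULTILINE)


def rewrite_list_annotations(source: str) -> str:
    """Rewrite ``list[...]`` to ``typing.List[...]`` and ensure ``import typing``."""
    patched, count = _BUILTIN_LIST.subn("typing.List[", source)
    if count == 0:
        return source
    if not _HAS_IMPORT_TYPING.search(patched):
        lines = patched.splitlines(keepends=True)
        # ``from __future__`` imports must stay first, so insert after them.
        insert_at = 0
        for i, line in enumerate(lines):
            if line.startswith("from __future__ import"):
                insert_at = i + 1
        lines.insert(insert_at, "import typing\n")
        patched = "".join(lines)
    return patched


def patch_file(path: Path) -> bool:
    """Patch ``path`` in place, keeping a ``.bak`` copy; return True if changed."""
    original = path.read_text()
    patched = rewrite_list_annotations(original)
    if patched == original:
        return False
    path.with_name(path.name + ".bak").write_text(original)
    path.write_text(patched)
    return True
```

The files to patch sit inside the installed cuequivariance_ops_torch package (the traceback shows triangle_multiplicative_update.py directly under site-packages/cuequivariance_ops_torch/). Edits like this are lost on reinstall, so they are only a stopgap until the annotations change upstream.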

GPU HW/SW:

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
nvidia-smi
Tue Jan  6 18:06:23 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-PCIE-40GB          On  |   00000000:D8:00.0 Off |                    0 |
| N/A   32C    P0             54W /  250W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
pip list
Package                       Version
----------------------------- ----------
absl-py                       2.3.1
aiohappyeyeballs              2.6.1
aiohttp                       3.13.3
aiosignal                     1.4.0
annotated-types               0.7.0
appdirs                       1.4.4
async-timeout                 5.0.1
attrs                         25.4.0
awscli                        2.32.29
awscrt                        0.29.1
biopython                     1.86
Brotli                        1.2.0
certifi                       2026.1.4
charset-normalizer            3.4.4
click                         8.3.1
colorama                      0.4.6
cuequivariance                0.8.0
cuequivariance-ops-cu12       0.8.0
cuequivariance-ops-torch-cu12 0.8.0
cuequivariance-torch          0.8.0
deepspeed                     0.14.5
distro                        1.8.0
DLLogger                      1.1.0
dm-tree                       0.1.6
docker-pycreds                0.4.0
docutils                      0.19
eval_type_backport            0.3.1
filelock                      3.20.2
frozenlist                    1.7.0
fsspec                        2025.12.0
gitdb                         4.0.12
GitPython                     3.1.46
gmpy2                         2.2.1
hjson                         3.1.0
idna                          3.11
ihm                           2.8
Jinja2                        3.1.6
jmespath                      1.0.1
lightning-utilities           0.15.2
MarkupSafe                    3.0.3
ml_collections                1.0.0
modelcif                      0.7
mpmath                        1.3.0
msgpack                       1.1.2
multidict                     6.7.0
networkx                      3.4.2
ninja                         1.13.0
numpy                         2.2.6
nvidia-cublas-cu12            12.9.1.4
nvidia-cuda-cupti-cu12        12.4.127
nvidia-cuda-nvrtc-cu12        12.4.127
nvidia-cuda-runtime-cu12      12.4.127
nvidia-cudnn-cu12             9.1.0.70
nvidia-cufft-cu12             11.2.1.3
nvidia-curand-cu12            10.3.5.147
nvidia-cusolver-cu12          11.6.1.9
nvidia-cusparse-cu12          12.3.1.170
nvidia-ml-py                  13.590.44
nvidia-nccl-cu12              2.21.5
nvidia-nvjitlink-cu12         12.4.127
nvidia-nvtx-cu12              12.4.127
openfold                      2.2.0
OpenMM                        8.2.0
opt_einsum                    3.4.0
packaging                     25.0
pandas                        2.3.3
pdbfixer                      1.12.0
pip                           25.3
platformdirs                  4.5.1
prompt_toolkit                3.0.51
propcache                     0.3.1
protobuf                      6.32.1
psutil                        7.2.1
py-cpuinfo                    9.0.0
pydantic                      2.12.5
pydantic_core                 2.41.5
PySocks                       1.7.1
python-dateutil               2.9.0
pytorch-lightning             2.6.0
pytz                          2025.2
PyYAML                        6.0.3
requests                      2.32.5
ruamel.yaml                   0.17.21
ruamel.yaml.clib              0.2.12
scipy                         1.15.2
sentry-sdk                    2.48.0
setproctitle                  1.3.7
setuptools                    59.5.0
six                           1.17.0
smmap                         5.0.2
sympy                         1.13.1
torch                         2.5.1
torchmetrics                  1.8.2
tqdm                          4.67.1
triton                        3.1.0
typing_extensions             4.15.0
typing-inspection             0.4.2
tzdata                        2025.3
urllib3                       1.26.20
wandb                         0.23.1
wcwidth                       0.2.14
wheel                         0.45.1
yarl                          1.22.0
