Skip to content

Commit 232a104

Browse files
authored
Ishaan/infer oxygen (#283)
1 parent 23b084b commit 232a104

File tree

10 files changed

+134
-148
lines changed

10 files changed

+134
-148
lines changed

cookbook/tutorials/2_embed.ipynb

Lines changed: 52 additions & 95 deletions
Large diffs are not rendered by default.

esm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
__version__ = "3.2.3"
1+
__version__ = "3.2.4.a0"

esm/sdk/api.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,7 @@ def from_protein_chain(
7777
sasa=protein_chain.sasa().tolist(),
7878
function_annotations=None,
7979
coordinates=torch.tensor(protein_chain.atom37_positions),
80+
plddt=torch.tensor(protein_chain.confidence),
8081
)
8182
else:
8283
return ESMProtein(
@@ -85,6 +86,7 @@ def from_protein_chain(
8586
sasa=None,
8687
function_annotations=None,
8788
coordinates=torch.tensor(protein_chain.atom37_positions),
89+
plddt=torch.tensor(protein_chain.confidence),
8890
)
8991

9092
@classmethod
@@ -104,6 +106,7 @@ def from_protein_complex(
104106
coordinates=torch.tensor(
105107
protein_complex.atom37_positions, dtype=torch.float32
106108
),
109+
plddt=torch.tensor(protein_complex.confidence),
107110
)
108111

109112
def to_pdb(self, pdb_path: PathOrBuffer) -> None:
@@ -325,7 +328,9 @@ def use_generative_unmasking_strategy(self):
325328
@define
326329
class InverseFoldingConfig:
327330
invalid_ids: Sequence[int] = []
328-
temperature: float = 1.0
331+
temperature: float = 0.1
332+
seed: int | None = None
333+
decode_in_residue_index_order: bool = False
329334

330335

331336
## Low Level Endpoint Types

esm/sdk/forge.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,8 @@ def process_inverse_fold_request(
119119
inverse_folding_config = {
120120
"invalid_ids": config.invalid_ids,
121121
"temperature": config.temperature,
122+
"seed": config.seed,
123+
"decode_in_residue_index_order": config.decode_in_residue_index_order,
122124
}
123125
request = {
124126
"coordinates": maybe_list(coordinates, convert_nan_to_none=True),

esm/utils/structure/molecular_complex.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -707,8 +707,9 @@ def to_mmcif(self) -> str:
707707
atom_array.chain_id = np.array(atom_chain_ids, dtype="U4")
708708
atom_array.res_name = np.array(atom_res_names, dtype="U4")
709709
atom_array.hetero = atom_hetero
710-
atom_array.b_factor = atom_bfactors
711710
atom_array.atom_name = np.array(atom_names, dtype="U4")
711+
atom_array.add_annotation("b_factor", dtype=float)
712+
atom_array.b_factor = atom_bfactors
712713

713714
# Use existing elements or infer them from atom names
714715
if self.atom_elements is not None and len(self.atom_elements) == n_atoms:

esm/utils/structure/protein_chain.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1121,7 +1121,9 @@ def normalize_coordinates(self) -> ProteinChain:
11211121

11221122
def infer_oxygen(self) -> ProteinChain:
11231123
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
1124-
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
1124+
O_missing_indices = np.argwhere(
1125+
~np.isfinite(self.atoms["O"]).all(axis=1)
1126+
).squeeze()
11251127

11261128
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
11271129
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)

esm/utils/structure/protein_complex.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -562,7 +562,9 @@ def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):
562562

563563
def infer_oxygen(self) -> ProteinComplex:
564564
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
565-
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
565+
O_missing_indices = np.argwhere(
566+
~np.isfinite(self.atoms["O"]).all(axis=1)
567+
).squeeze()
566568

567569
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
568570
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)

pixi.lock

Lines changed: 58 additions & 45 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "esm"
3-
version = "3.2.3"
3+
version = "3.2.4.a0"
44
description = "EvolutionaryScale open model repository"
55
readme = "README.md"
66
requires-python = ">=3.12,<3.13"
@@ -24,7 +24,7 @@ dependencies = [
2424
"torch>=2.2.0",
2525
"torchvision",
2626
"torchtext",
27-
"transformers<4.48.2",
27+
"transformers==4.52.4",
2828
"ipython",
2929
"einops",
3030
"biotite>=1.0.0",

tests/Makefile

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,11 @@ DOCKER_TAG ?= dev
33
DOCKER_IMAGE_OSS=oss_pytests:${DOCKER_TAG}
44

55
build-oss-ci:
6-
docker build -f oss_pytests/Dockerfile oss_pytests -t $(DOCKER_IMAGE_OSS)
6+
docker build \
7+
--output=type=docker \
8+
-f oss_pytests/Dockerfile \
9+
-t $(DOCKER_IMAGE_OSS) \
10+
oss_pytests
711

812
start-docker-oss:
913
docker run \

0 commit comments

Comments
 (0)