Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/graphpro/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from graphpro.util.modes import compute_gnm_slow_modes, compute_anm_slow_modes
from graphpro.util.dssp import compute_dssp, DSSP_CLASS
from graphpro.util.polarity import POLARITY_CLASSES, residue_polarity
from graphpro.util.conservation import ConservationScoreClient

class NodeTargetBinaryAttribute(NodeTarget):
""" Binary target, creates a binary one_hot encoding of the property
Expand Down Expand Up @@ -191,4 +192,24 @@ def generate(self, G: Graph, atom_group: AtomGroup):
def encode(self, G: Graph) -> torch.tensor:
    """Encode each node's stored class label as a one-hot float tensor.

    For every node the value stored under ``self.attr_name`` is looked up;
    nodes without that attribute fall back to the label ``'U'``. Labels are
    mapped to their position in ``DSSP_CLASS`` and one-hot encoded with
    ``len(DSSP_CLASS)`` classes.
    """
    class_indices = []
    for node in G.nodes():
        attrs = G.node_attr(node)
        label = attrs[self.attr_name] if self.attr_name in attrs else 'U'
        class_indices.append(DSSP_CLASS.index(label))

    index_tensor = torch.tensor(class_indices, dtype=torch.int64)
    one_hot = F.one_hot(index_tensor, num_classes=len(DSSP_CLASS))
    return one_hot.to(torch.float)

class ConservationScore(NodeAnnotation):
    """Annotate every residue node with a conservation score.

    Scores are obtained through ``ConservationScoreClient``; by default the
    Shannon conservation computed by the server graphpro.pegerto.com is used.
    """

    def __init__(self, attr_name: str = 'cons_shannon'):
        """Store the node-attribute name under which scores are written."""
        self.attr_name = attr_name

    def generate(self, G: Graph, atom_group: AtomGroup):
        """Fetch the scores and attach each one to its matching node.

        The client is created from ``str(G)`` and ``G.metadata.chain``.
        ``atom_group`` is accepted but not used by this method.
        """
        client = ConservationScoreClient(str(G), G.metadata.chain)
        residue_scores = client.compute_conservation_score()

        for residue_id, value in residue_scores:
            # NOTE(review): assumes every returned residue id resolves to a
            # node of G -- confirm how get_node_by_resid handles a miss.
            target = G.get_node_by_resid(residue_id)
            G.node_attr_add(target, self.attr_name, value)

    def encode(self, G: Graph) -> torch.tensor:
        """Return the scores as a normalized single-column float tensor.

        Nodes without the attribute contribute ``0``. The values are laid
        out as one column and normalized over both dimensions at once.
        """
        values = []
        for node in G.nodes():
            attrs = G.node_attr(node)
            values.append(attrs[self.attr_name] if self.attr_name in attrs else 0)

        column = torch.tensor([values], dtype=torch.float).T
        return F.normalize(column, dim=(0, 1))
8 changes: 5 additions & 3 deletions src/graphpro/graphgen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from .graph import Graph
from .graph import Graph, ProteinMetadata
from .collection import GraphCollection


Expand Down Expand Up @@ -28,7 +28,8 @@ def generate(self, ag, name: str):
ca_position = ag.c_alphas_positions(self.chain)
dist = distance.squareform(distance.pdist(ca_position))
dist[dist > self.cutoff] = 0
return Graph(name, dist, ca_position, ag.c_alphas_residues(self.chain))
metadata = ProteinMetadata(uniprot_id=None, chain=self.chain)
return Graph(name, dist, ca_position, ag.c_alphas_residues(self.chain), metadata)

class KNN(RepresentationMethod):
""" Generate the structure form a defined number of neighbours
Expand All @@ -49,7 +50,8 @@ def generate(self, ag, name: str):
for j in neig:
adjacency[i,j] = 1

return Graph(name, adjacency, ca_position, ag.c_alphas_residues(self.chain))
metadata = ProteinMetadata(uniprot_id=None, chain=self.chain)
return Graph(name, adjacency, ca_position, ag.c_alphas_residues(self.chain), metadata)



Expand Down
16 changes: 16 additions & 0 deletions src/graphpro/util/conservation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
class ConservationScoreClient:
    """Small HTTP client for the conservation-score service.

    A POST request is sent to
    ``https://<server>/conservation/<pdb_id>/<chain_id>`` and the per-residue
    ``shannon`` values of the requested chain are extracted from the JSON
    response.
    """

    def __init__(self, pdb_id: str, chain_id: str, server: str = 'graphpro.pegerto.com',
                 *, timeout: float = 300.0, verify: bool = False):
        """Store the request parameters.

        Args:
            pdb_id: Structure identifier placed in the request path.
            chain_id: Chain identifier placed in the request path and used
                to select the chain in the response.
            server: Host name of the conservation service.
            timeout: Seconds to wait for the server before giving up. Without
                a timeout ``requests`` may block indefinitely.
            verify: Whether the TLS certificate of the server is verified.
        """
        self.pdb_id = pdb_id
        self.chain_id = chain_id
        self.server = server
        self.timeout = timeout
        # NOTE(review): the default stays False only to preserve the existing
        # behaviour; certificate checking should be enabled (verify=True) as
        # soon as the server certificate is confirmed to be valid.
        self.verify = verify

    def compute_conservation_score(self) -> list[tuple[int, float]]:
        """Request the scores and return ``(residue id, shannon score)`` pairs.

        Raises:
            RuntimeError: If the server answers with a status other than 200.
            KeyError: If the response holds no entry for ``chain_id``.
        """
        # Imported lazily so that the module can be imported without requests.
        import requests
        from urllib.parse import quote

        # Escape both path segments so characters such as '/' or '?' in an
        # identifier cannot change the route that is requested.
        pdb_segment = quote(str(self.pdb_id), safe='')
        chain_segment = quote(str(self.chain_id), safe='')
        resource = f'https://{self.server}/conservation/{pdb_segment}/{chain_segment}'

        resp = requests.post(resource, verify=self.verify, timeout=self.timeout)
        if resp.status_code != 200:
            raise RuntimeError(f"Error calculating conservation: {resp.status_code}")

        results = resp.json()['chains'][self.chain_id]
        return [(int(resid), float(score['shannon'])) for resid, score in results.items()]
28 changes: 28 additions & 0 deletions test/graphpro/annnotations/conformation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os
import MDAnalysis as mda

from graphpro import md_analisys
from graphpro.graphgen import ContactMap
from graphpro.annotations import ConservationScore


# Universe shared by the tests below, loaded from the 4aw0.pdb fixture that
# lives two directories above this file.
_TEST_DIR = os.path.dirname(os.path.realpath(__file__))
u3aw0 = mda.Universe(_TEST_DIR + '/../../testdata/4aw0.pdb')


def test_conservation():
    """Check node count and the conservation scores of the first three nodes.

    The scores are floating-point values, so they are compared with a
    tolerance instead of exact equality.
    """
    import math

    G = md_analisys(u3aw0, '4AW0').generate(ContactMap(cutoff=6, chain='A'), [ConservationScore()])
    assert len(G.nodes()) == 283

    expected_scores = (0.8285890860167484, 0.8487017361472784, 1.330480269847359)
    for node, expected in enumerate(expected_scores):
        actual = G.node_attr(node)['cons_shannon']
        assert math.isclose(actual, expected, rel_tol=1e-9), \
            f'node {node}: got {actual}, expected {expected}'

def test_conservation_encoding():
    """The conservation encoder must yield one feature column per node."""
    generator = md_analisys(u3aw0, '4AW0')
    representation = ContactMap(cutoff=6, chain='A')
    annotations = [ConservationScore()]
    G = generator.generate(representation, annotations)

    encoders = [ConservationScore()]
    data = G.to_data(node_encoders=encoders)

    expected_shape = (283, 1)
    assert data.x.size() == expected_shape
10 changes: 9 additions & 1 deletion test/graphpro/graphgen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from MDAnalysis.tests.datafiles import PDB, XTC

u1 = mda.Universe(PDB, XTC)

# Universe loaded from the 5htc.pdb fixture in the sibling testdata directory.
_TESTS_DIR = os.path.dirname(os.path.realpath(__file__))
FIVEHTC = mda.Universe(_TESTS_DIR + '/../testdata/5htc.pdb')

def test_graph_generation_from_mdanalysis():
G = md_analisys(u1).generate(ContactMap(cutoff=6))
Expand All @@ -25,6 +28,11 @@ def test_graph_generation_from_mdanalysis_custom_residue():
assert(len(G.nodes()) == 5752)


def test_graph_generation_from_mdanalysis_with_chain_in_metadata():
    """A graph built for chain 'A' must record that chain in its metadata."""
    generator = md_analisys(FIVEHTC)
    G = generator.generate(ContactMap(cutoff=6, chain='A'))

    node_count = len(G.nodes())
    assert node_count == 328
    assert G.metadata.chain == 'A'

def test_graph_generation_collection():
graph_col = md_analisys(u1).generate_trajectory(ContactMap(cutoff=6))

Expand Down
Loading