diff --git a/src/graphpro/annotations.py b/src/graphpro/annotations.py index d39d984..fd8d2e1 100644 --- a/src/graphpro/annotations.py +++ b/src/graphpro/annotations.py @@ -11,6 +11,7 @@ from graphpro.util.modes import compute_gnm_slow_modes, compute_anm_slow_modes from graphpro.util.dssp import compute_dssp, DSSP_CLASS from graphpro.util.polarity import POLARITY_CLASSES, residue_polarity +from graphpro.util.conservation import ConservationScoreClient class NodeTargetBinaryAttribute(NodeTarget): """ Binary target, creates a binary one_hot encoding of the property @@ -191,4 +192,24 @@ def generate(self, G: Graph, atom_group: AtomGroup): def encode(self, G: Graph) -> torch.tensor: secondary = [G.node_attr(n)[self.attr_name] if self.attr_name in G.node_attr(n) else 'U' for n in G.nodes()] secondary_class = [DSSP_CLASS.index(p) for p in secondary] - return F.one_hot(torch.tensor(secondary_class, dtype=torch.int64), num_classes=len(DSSP_CLASS)).to(torch.float) \ No newline at end of file + return F.one_hot(torch.tensor(secondary_class, dtype=torch.int64), num_classes=len(DSSP_CLASS)).to(torch.float) + +class ConservationScore(NodeAnnotation): + """Computes the conservation score for each residue in the graph by default shannon conservation is calculated + using the server graphpro.pegerto.com + """ + def __init__(self, attr_name: str = 'cons_shannon'): + """ Attribute name + """ + self.attr_name = attr_name + + def generate(self, G: Graph, atom_group: AtomGroup): + scorer = ConservationScoreClient(str(G), G.metadata.chain) + + for resid, score in scorer.compute_conservation_score(): + node_id = G.get_node_by_resid(resid) + G.node_attr_add(node_id,self.attr_name, score) + + def encode(self, G: Graph) -> torch.tensor: + scores = [G.node_attr(n)[self.attr_name] if self.attr_name in G.node_attr(n) else 0 for n in G.nodes()] + return F.normalize(torch.tensor([scores], dtype=torch.float).T, dim=(0,1)) \ No newline at end of file diff --git a/src/graphpro/graphgen.py b/src/graphpro/graphgen.py index 50b8cfe..e0c695f 100644 --- a/src/graphpro/graphgen.py +++ b/src/graphpro/graphgen.py @@ -1,5 +1,5 @@ import numpy as np -from .graph import Graph +from .graph import Graph, ProteinMetadata from .collection import GraphCollection @@ -28,7 +28,8 @@ def generate(self, ag, name: str): ca_position = ag.c_alphas_positions(self.chain) dist = distance.squareform(distance.pdist(ca_position)) dist[dist > self.cutoff] = 0 - return Graph(name, dist, ca_position, ag.c_alphas_residues(self.chain)) + metadata = ProteinMetadata(uniprot_id=None, chain=self.chain) + return Graph(name, dist, ca_position, ag.c_alphas_residues(self.chain), metadata) class KNN(RepresentationMethod): """ Generate the structure form a defined number of neighbours @@ -49,7 +50,8 @@ def generate(self, ag, name: str): for j in neig: adjacency[i,j] = 1 - return Graph(name, adjacency, ca_position, ag.c_alphas_residues(self.chain)) + metadata = ProteinMetadata(uniprot_id=None, chain=self.chain) + return Graph(name, adjacency, ca_position, ag.c_alphas_residues(self.chain), metadata) diff --git a/src/graphpro/util/conservation.py b/src/graphpro/util/conservation.py new file mode 100644 index 0000000..605d94e --- /dev/null +++ b/src/graphpro/util/conservation.py @@ -0,0 +1,16 @@ +class ConservationScoreClient: + def __init__(self, pdb_id: str, chain_id: str, server: str = 'graphpro.pegerto.com'): + self.pdb_id = pdb_id + self.chain_id = chain_id + self.server = server + + def compute_conservation_score(self) -> list[tuple[int,float]]: + import requests + + resouce = f'https://{self.server}/conservation/{self.pdb_id}/{self.chain_id}' + resp = requests.post(resouce, verify=False) + if resp.status_code != 200: + raise Exception(f"Error calculating conservation: {resp.status_code}") + + results = resp.json()['chains'][self.chain_id] + return [(int(resid), float(score['shannon'])) for resid, score in results.items()] \ No newline at end of file diff --git a/test/graphpro/annnotations/conformation_test.py b/test/graphpro/annnotations/conformation_test.py new file mode 100644 index 0000000..d7729fd --- /dev/null +++ b/test/graphpro/annnotations/conformation_test.py @@ -0,0 +1,28 @@ +import os +import MDAnalysis as mda + +from graphpro import md_analisys +from graphpro.graphgen import ContactMap +from graphpro.annotations import ConservationScore + + +u3aw0 = mda.Universe( + os.path.dirname( + os.path.realpath(__file__)) + + '/../../testdata/4aw0.pdb') + + +def test_conservation(): + G = md_analisys(u3aw0, '4AW0').generate(ContactMap(cutoff=6, chain='A'), [ConservationScore()]) + assert len(G.nodes()) == 283 + print(G.node_attr(0)) + print(G.node_attr(1)) + print(G.node_attr(2)) + assert G.node_attr(0)['cons_shannon'] == 0.8285890860167484 + assert G.node_attr(1)['cons_shannon'] == 0.8487017361472784 + assert G.node_attr(2)['cons_shannon'] == 1.330480269847359 + +def test_conservation_encoding(): + G = md_analisys(u3aw0, '4AW0').generate(ContactMap(cutoff=6, chain='A'), [ConservationScore()]) + data = G.to_data(node_encoders=[ConservationScore()]) + assert data.x.size() == (283, 1) \ No newline at end of file diff --git a/test/graphpro/graphgen_test.py b/test/graphpro/graphgen_test.py index 049d1ac..558ffb4 100644 --- a/test/graphpro/graphgen_test.py +++ b/test/graphpro/graphgen_test.py @@ -9,7 +9,10 @@ from MDAnalysis.tests.datafiles import PDB, XTC u1 = mda.Universe(PDB, XTC) - +FIVEHTC = mda.Universe( + os.path.dirname( + os.path.realpath(__file__)) + + '/../testdata/5htc.pdb') def test_graph_generation_from_mdanalysis(): G = md_analisys(u1).generate(ContactMap(cutoff=6)) @@ -25,6 +28,11 @@ def test_graph_generation_from_mdanalysis_custom_residue(): assert(len(G.nodes()) == 5752) +def test_graph_generation_from_mdanalysis_with_chain_in_metadata(): + G = md_analisys(FIVEHTC).generate(ContactMap(cutoff=6, chain='A')) + assert(len(G.nodes()) == 328) + assert G.metadata.chain == 'A' + def test_graph_generation_collection(): graph_col = md_analisys(u1).generate_trajectory(ContactMap(cutoff=6))