diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 381509f..048dcda 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,70 +1,29 @@ ci: - autofix_commit_msg: | - [pre-commit.ci] auto fixes from pre-commit.com hooks - for more information, see https://pre-commit.ci - autofix_prs: true - autoupdate_branch: '' - autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' - autoupdate_schedule: weekly - submodules: false - + autofix_commit_msg: | + [pre-commit.ci] auto fixes from pre-commit.com hooks + for more information, see https://pre-commit.ci + autofix_prs: true + autoupdate_branch: "" + autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" + autoupdate_schedule: weekly + submodules: false repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.1.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - exclude: docs - id: check-added-large-files - args: ['--maxkb=100000'] + args: ["--maxkb=100000"] - id: end-of-file-fixer - exclude: docs - id: check-yaml args: ["--unsafe"] - - - repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - name: Fixes formatting - language_version: python3 - args: ["--line-length=120"] - - - - - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.8 hooks: - - id: flake8 - name: Checks pep8 style - args: [ - "--max-line-length=120", - # Ignore imports in init files - "--per-file-ignores= - */__init__.py:F401,setup.py:E121,*/test_jnbs/props/euler_scripts/*.py:F401 E722 E711 E266, - ", - # ignore long comments (E501), as long lines are formatted by black - # ignore Whitespace before ':' (E203) - # ignore lambdas (E731) - # ignore Line break occurred before a binary operator (W503) - # needed to not remove * imports (for example in _all_blocks.py) - "--ignore=E501,E203,E231,E731,W503,F405", - ] - - - repo: local - hooks: - - id: jupyisort - name: Sorts ipynb imports - entry: jupytext --pipe-fmt ".py" --pipe "isort - --multi-line=3 --trailing-comma --force-grid-wrap=0 --use-parentheses --line-width=99" --sync - files: \.ipynb$ - language: python - stages: [pre-push] - - - id: jupyblack - name: Fixes ipynb format - entry: jupytext --pipe-fmt ".py" --pipe "black - --line-length=120" --sync - files: \.ipynb$ - language: python - stages: [pre-push] + - id: ruff + files: ^serenityff/|^tests/ + exclude: \.ipynb + - id: ruff-format + files: ^serenityff/|^tests/ diff --git a/dev/conda-env/test_env.yaml b/dev/conda-env/test_env.yaml index 6a6d625..87c569f 100644 --- a/dev/conda-env/test_env.yaml +++ b/dev/conda-env/test_env.yaml @@ -15,6 +15,7 @@ dependencies: - codecov - ipython - pre-commit + - ruff # Meta - conda-build # Science diff --git a/pyproject.toml b/pyproject.toml index 6a7aea3..43660cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,5 +40,19 @@ include-package-data = true [project.entry-points."openff.toolkit.plugins.handlers"] SerenityFFCharge = "serenityff.charge.utils.serenityff_charge_handler:SerenityFFChargeHandler" -[tool.black] +[tool.ruff] line-length = 120 + +[tool.ruff.lint] +fixable = ["I"] +select = [ + "E", # pycodestyle error + "F", # pyflakes + "I", # isort + "W", # pycodestyle warning +] + +[tool.ruff.format] +quote-style = "double" +line-ending = "auto" +indent-style = "space" diff --git a/serenityff/charge/dataset_preperation/MolMorganDataset.py b/serenityff/charge/dataset_preperation/MolMorganDataset.py index 1e18f2b..64a82d3 100644 --- a/serenityff/charge/dataset_preperation/MolMorganDataset.py +++ b/serenityff/charge/dataset_preperation/MolMorganDataset.py @@ -1,22 +1,22 @@ # Copyright (C) 2022-2025 ETH Zurich, Jakob Teetz and other DASH contributors. import os -import numpy import pickle +import numpy + try: from typing import Literal except ImportError: from typing_extensions import Literal -from typing import Optional from collections import defaultdict -from rdkit import Chem -from rdkit.Chem import Descriptors -from rdkit.Chem import Draw -from rdkit.Chem import AllChem +from typing import Optional + from matplotlib import lines from matplotlib import pyplot as plt +from rdkit import Chem +from rdkit.Chem import AllChem, Descriptors, Draw def shrink(KeyAppearanceDict: defaultdict, Cutoff: int) -> None: @@ -49,8 +49,8 @@ def ddict2dict(d: defaultdict) -> dict: class MolMorganDataset: """ - Storage of Morgan fingerprints of a molecule dataset (stored in sdf-file). Used to compare to other MolMorganDatasets on a fingerprint count level. And other helpful - tools for data cleaning, management and merging. + Storage of Morgan fingerprints of a molecule dataset (stored in sdf-file). Used to compare to other + MolMorganDatasets on a fingerprint count level. And other helpful tools for data cleaning, management and merging. """ def __init__( @@ -60,13 +60,18 @@ def __init__( ConsiderEquivalentKeysOnce: bool = True, FilterForChargedMolecules: Optional[Literal["atom", "molecule", "none"]] = "molecule", ) -> None: - """Initializes a new MolMorganDataset and reads in the data from a pickle file, if available, or calculates it new. + """Initializes a new MolMorganDataset and reads in the data from a pickle file, if available, or calculates it + new. Args: DataPath (str): path to sdf file - UseOldData (bool, optional): If False, data will be calculated new, even if a pickle file is available. Defaults to True. - ConsiderEquivalentKeysOnce (bool, optional): If False, counts total appearances of fingerprint per molecule. Otherwise each fingerprint is just counted once. Defaults to True. - FilterForChargedMolecules (Literal, optional): For 'atom' ignores all molecules that contain a charged atom. For 'molecule' ignores all molecules with a formal charge. For 'none' takes all molecules. Default set to 'molecule' + UseOldData (bool, optional): If False, data will be calculated new, even if a pickle file is available. + Defaults to True. + ConsiderEquivalentKeysOnce (bool, optional): If False, counts total appearances of fingerprint per molecule. + Otherwise each fingerprint is just counted once. Defaults to True. + FilterForChargedMolecules (Literal, optional): For 'atom' ignores all molecules that contain a charged atom. + For 'molecule' ignores all molecules with a formal charge. For 'none' takes all molecules. Default set to + 'molecule' """ self._data_path = DataPath self._folder_path = os.path.split(self._data_path)[0] @@ -110,12 +115,17 @@ def ReadData( ConsiderEquivalentKeysOnce: bool = True, FilterForChargedMolecules: Optional[Literal["atom", "molecule", "none"]] = "molecule", ) -> None: - """Read in data from pickle file if present. Otherwise calls CalcData function to calculate data and store it in a pickle file. + """Read in data from pickle file if present. Otherwise calls CalcData function to calculate data and store it + in a pickle file. Args: - UseOldData (bool, optional): If False, data will be calculated new, even if a pickle file is available. Defaults to True. - ConsiderEquivalentKeysOnce (bool, optional): If False, counts total appearances of fingerprint per molecule. Otherwise each fingerprint is just counted once. Defaults to True. - FilterForChargedMolecules (Literal, optional): For 'atom' ignores all molecules that contain a charged atom. For 'molecule' ignores all molecules with a formal charge. For 'none' takes all molecules. Default set to 'molecule' + UseOldData (bool, optional): If False, data will be calculated new, even if a pickle file is available. + Defaults to True. + ConsiderEquivalentKeysOnce (bool, optional): If False, counts total appearances of fingerprint per molecule. + Otherwise each fingerprint is just counted once. Defaults to True. + FilterForChargedMolecules (Literal, optional): For 'atom' ignores all molecules that contain a charged atom. + For 'molecule' ignores all molecules with a formal charge. For 'none' takes all molecules. Default set to + 'molecule' """ if UseOldData and os.path.isfile(self._saved_path): print("loaded stored data for " + self._filename) @@ -144,11 +154,15 @@ def CalcData( ConsiderEquivalentKeysOnce: bool = True, FilterForChargedMolecules: Optional[Literal["atom", "molecule", "none"]] = "molecule", ) -> None: - """Calculate various data for the MolMorganDataset(Appearances of Morgan fingerprints(radius 2), number of charged/uncharged/total molecules, molarweights) and write them into a pickle file + """Calculate various data for the MolMorganDataset(Appearances of Morgan fingerprints(radius 2), number of + charged/uncharged/total molecules, molarweights) and write them into a pickle file Args: - ConsiderEquivalentKeysOnce (bool, optional): If False, counts total appearances of fingerprint per molecule. Otherwise each fingerprint is just counted once. Defaults to True. - FilterForChargedMolecules (Literal, optional): For 'atom' ignores all molecules that contain a charged atom. For 'molecule' ignores all molecules with a formal charge. For 'none' takes all molecules. Default set to 'molecule' + ConsiderEquivalentKeysOnce (bool, optional): If False, counts total appearances of fingerprint per molecule. + Otherwise each fingerprint is just counted once. Defaults to True. + FilterForChargedMolecules (Literal, optional): For 'atom' ignores all molecules that contain a charged atom. + For 'molecule' ignores all molecules with a formal charge. For 'none' takes all molecules. Default set to + 'molecule' """ prevID = 0 @@ -161,7 +175,8 @@ def CalcData( for m in self._mols: # loop over all molecules in MolMorganDataset if bool(m.HasProp("CHEMBL_ID")): if prevID == m.GetProp("CHEMBL_ID"): - self._num_mol += 1 # count amount of molecules. Conformers inlcuded so it later matches the index of molecule list + self._num_mol += 1 # count amount of molecules. Conformers inlcuded so it later matches the index + # of molecule list continue AllChem.GetMorganFingerprint(m, 2, bitInfo=fp) for val in list(fp): # loop over fingerprints of the molecule @@ -174,9 +189,9 @@ def CalcData( if fp[val][0][1] < 1: # fingerprints for radius 0 self._key_dict[0][val] += len(fp[val]) else: - self._key_dict[2][ - val - ] += 1 # counting appearances of different keys (with counting equivalent atoms as one) + self._key_dict[2][val] += ( + 1 # counting appearances of different keys (with counting equivalent atoms as one) + ) if fp[val][0][1] < 2: # fingerprints for radius 1 self._key_dict[1][val] += 1 if fp[val][0][1] < 1: # fingerprints for radius 0 @@ -208,9 +223,9 @@ def CalcData( if fp[val][0][1] < 1: # fingerprints for radius 0 self._key_dict_nocharge[0][val] += len(fp[val]) else: - self._key_dict_nocharge[2][ - val - ] += 1 # counting appearances of different keys (with counting equivalent atoms as one) + self._key_dict_nocharge[2][val] += ( + 1 # counting appearances of different keys (with counting equivalent atoms as one) + ) if fp[val][0][1] < 2: # fingerprints for radius 1 self._key_dict_nocharge[1][val] += 1 if fp[val][0][1] < 1: # fingerprints for radius 0 @@ -289,7 +304,8 @@ def weight_distribution(self, IgnoreH: bool = False) -> None: int(numpy.ceil(max(weight))) + int(numpy.ceil(max(weight))) % 2, 2, ), - ) # make bin width of 2 as it produces substructure for uneven weights otherwise. %2 terms to ensure range is dividable by 2 + ) # make bin width of 2 as it produces substructure for uneven weights otherwise. %2 terms to ensure range + # is dividable by 2 plt.xlabel("Molecular weight [u]") plt.ylabel("Counts") plt.title("Weight distribution of dataset " + self._filename) @@ -323,11 +339,13 @@ def intersection(self, otherset) -> set[str]: return intersect # return list of smiles that are present in both sets def compare(self, otherset, UseChargedMolecules: bool = False) -> None: - """Compares the appearance of morgan fingerprints in MolMorganDatasets and plots them according to the metric (#set1 - #set2)/(#set1 + #set2) + """Compares the appearance of morgan fingerprints in MolMorganDatasets and plots them according to the metric + (#set1 - #set2)/(#set1 + #set2) Args: otherset (MolMorganDataset): Used for comparison - UseChargedMolecules (bool, optional): If False, excludes fingerprints from charged molecules. Defaults to False. + UseChargedMolecules (bool, optional): If False, excludes fingerprints from charged molecules. Defaults to + False. """ print("comparing... ") self._sort = defaultdict( @@ -378,7 +396,8 @@ def compare(self, otherset, UseChargedMolecules: bool = False) -> None: key: (ownkeys.get(key, 0) - otherkeys.get(key, 0)) / added.get( key, 0 - ) # create dictionary with the calculated metric to measure the difference in the data set: 1 == only appear in new set; -1 == only in old set + ) # create dictionary with the calculated metric to measure the difference in the data set: 1 == only + # appear in new set; -1 == only in old set for key in set(otherkeys) | set(ownkeys) } t = { @@ -397,13 +416,14 @@ def compare(self, otherset, UseChargedMolecules: bool = False) -> None: key: (cnew.get(key, 0) - cold.get(key, 0)) / add.get(key, 0) for key in set(cold) | set(cnew) } c = {k: v for k, v in sorted(add_met.items(), key=lambda item: item[1], reverse=True)} # sort by value - for (k, v) in t.items(): # go through all possible keys + for k, v in t.items(): # go through all possible keys if k not in c.keys(): # when they were deleted in this iteration of shrinking .... if k not in hold.keys(): # and they are not yet in included in new... hold[k] = v # add them col[k] = i - 1 # and store the color depending on the iteration in which they were added - # This means elements that only appear once are added first and then the once that appear twice and so on --> they later keep the order - for (k, v) in t.items(): # finally add all keys that were left over + # This means elements that only appear once are added first and then the once that appear + # twice and so on --> they later keep the order + for k, v in t.items(): # finally add all keys that were left over if k not in hold.keys(): hold[k] = v col[k] = 4 @@ -559,13 +579,17 @@ def plot_compare(self, otherset): def missings( self, otherset, radius: int = 0, UseChargedMolecules: bool = False, DrawMolecules: bool = False ) -> list: - """Finds the fingerprints are missing that appear in other dataset and prints from how many molecules they origin(this number can be further reduced) + """Finds the fingerprints are missing that appear in other dataset and prints from how many molecules they + origin(this number can be further reduced) Args: otherset (Dataset): Used for comparison - radius (int, optional): Radius (maximum of 2) for which the missing fingerprints are calculated. Defaults to 0. - UseChargedMolecules (bool, optional): If False, excludes fingerprints that come from charged molecules. Defaults to False. - DrawMolecules (bool, optional): If True, prints a Grid-Image of molecules with missing fingerprints sorted by number of missing fingerprints. Defaults to False. + radius (int, optional): Radius (maximum of 2) for which the missing fingerprints are calculated. Defaults to + 1. + UseChargedMolecules (bool, optional): If False, excludes fingerprints that come from charged molecules. + Defaults to False. + DrawMolecules (bool, optional): If True, prints a Grid-Image of molecules with missing fingerprints sorted + by number of missing fingerprints. Defaults to False. Returns: List of molecules if DrawMolecules false, otherwise returns a image """ @@ -632,7 +656,8 @@ def missings( if DrawMolecules: self._missingmols = [ x for x in self._missingmols if Descriptors.ExactMolWt(x) <= 750 - ] # only use molecules smaller than 750u. If large ones are included all molecules will be scaled accordingly and are hard to see + ] # only use molecules smaller than 750u. If large ones are included all molecules will be scaled + # accordingly and are hard to see if len(self._missingmols) != 0: grid_img = Draw.MolsToGridImage( self._missingmols, molsPerRow=5, subImgSize=(500, 500), highlightAtomLists=highlights @@ -679,8 +704,9 @@ def add(self, otherset, NewSetName: str) -> None: def reduce(self, NewSetName: str, otherset: str = "none", cutoff: int = 5) -> None: """Without other set: Deletes molecules with only reduntant fingerprints, that appear more often than the cutoff - With other set: Adds molecules from other set so that all fingerprints are represented at least as often as the cutoff, if possible. - (Both use a greedy approach: Molecules that have the highest amount of desired fingerprints are added first.) + With other set: Adds molecules from other set so that all fingerprints are represented at least as often as the + cutoff, if possible. (Both use a greedy approach: Molecules that have the highest amount of desired + fingerprints are added first.) Args: NewSetName (str): Name of new file @@ -708,18 +734,18 @@ def reduce(self, NewSetName: str, otherset: str = "none", cutoff: int = 5) -> No with Chem.SDWriter(outputpath) as w: if ( otherset == "none" - ): # If no other set is given. Reduce number of molecules in set to minimum needed to have all fingerprints at least as often as the cutoff.(If enough are there) + ): # If no other set is given. Reduce number of molecules in set to minimum needed to have all fingerprints + # at least as often as the cutoff.(If enough are there) print("Removing redundants...") - for ( - mol - ) in ( + for mol in ( self._mols ): # first step: Only keep molecules that contain at least one key appearing 'cutoff' times or less. AllChem.GetMorganFingerprint(mol, 2, bitInfo=fp) for key in fp: if ( self._key_dict[2][key] < cutoff - ): # if any key in molecule appears in total less than cutoff, mark as usefull. They are needed no matter what! + ): # if any key in molecule appears in total less than cutoff, mark as usefull. They are needed + # no matter what! usefull = 1 if usefull == 1: # write usefull molecules to new file if ( @@ -755,7 +781,8 @@ def reduce(self, NewSetName: str, otherset: str = "none", cutoff: int = 5) -> No AllChem.GetMorganFingerprint(mol, 2, bitInfo=fp) for key in fp: if keydictnew[key] < cutoff: - news += 1 # count how many keys from each molecule in leftovers has, that are desired for new set(that are still under cutoff) + news += 1 # count how many keys from each molecule in leftovers has, that are desired for new + # set(that are still under cutoff) ranking[i] = news news = 0 ranking = { diff --git a/serenityff/charge/dataset_preperation/dummy_dataset/dummy_set.ipynb b/serenityff/charge/dataset_preperation/dummy_dataset/dummy_set.ipynb index 16c90d4..2eb8f5e 100644 --- a/serenityff/charge/dataset_preperation/dummy_dataset/dummy_set.ipynb +++ b/serenityff/charge/dataset_preperation/dummy_dataset/dummy_set.ipynb @@ -7,8 +7,10 @@ "outputs": [], "source": [ "import os\n", + "\n", "from rdkit import Chem\n", - "from serenityff.charge.dataset_preperation.MolMorganDataset import MolMorganDataset\n" + "\n", + "from serenityff.charge.dataset_preperation.MolMorganDataset import MolMorganDataset" ] }, { @@ -26,8 +28,8 @@ } ], "source": [ - "dummy_set1 = MolMorganDataset('./dummyset1.sdf')\n", - "dummy_set2 = MolMorganDataset('./dummyset2.sdf')" + "dummy_set1 = MolMorganDataset(\"./dummyset1.sdf\")\n", + "dummy_set2 = MolMorganDataset(\"./dummyset2.sdf\")" ] }, { @@ -85,7 +87,7 @@ } ], "source": [ - "dummy_set1.missings(dummy_set2, DrawMolecules=True, radius = 0)" + "dummy_set1.missings(dummy_set2, DrawMolecules=True, radius=0)" ] }, { @@ -103,8 +105,8 @@ } ], "source": [ - "newset = dummy_set1.add(dummy_set2, NewSetName='combined_sets')\n", - "#does automatically newset = MolMorganDataset('./combined_sets.sdf')" + "newset = dummy_set1.add(dummy_set2, NewSetName=\"combined_sets\")\n", + "# does automatically newset = MolMorganDataset('./combined_sets.sdf')" ] }, { diff --git a/serenityff/charge/dataset_preperation/set_creation.ipynb b/serenityff/charge/dataset_preperation/set_creation.ipynb index 61869aa..3e46bc8 100644 --- a/serenityff/charge/dataset_preperation/set_creation.ipynb +++ b/serenityff/charge/dataset_preperation/set_creation.ipynb @@ -33,16 +33,16 @@ "metadata": {}, "outputs": [], "source": [ - "from serenityff.charge.dataset_preperation.MolMorganDataset import MolMorganDataset\n", "import pandas as pd\n", - "\n", "from rdkit import Chem\n", "\n", - "corrected = MolMorganDataset('path/to/corrected.sdf')\n", - "leadlike = MolMorganDataset('path/to/leadlike.sdf')\n", - "solvents = MolMorganDataset('path/to/solvents.sdf')\n", - "qmugs500 = MolMorganDataset('path/to/qmugs500.sdf')\n", - "noH500 = MolMorganDataset('path/to/noH500.sdf')" + "from serenityff.charge.dataset_preperation.MolMorganDataset import MolMorganDataset\n", + "\n", + "corrected = MolMorganDataset(\"path/to/corrected.sdf\")\n", + "leadlike = MolMorganDataset(\"path/to/leadlike.sdf\")\n", + "solvents = MolMorganDataset(\"path/to/solvents.sdf\")\n", + "qmugs500 = MolMorganDataset(\"path/to/qmugs500.sdf\")\n", + "noH500 = MolMorganDataset(\"path/to/noH500.sdf\")" ] }, { @@ -61,7 +61,7 @@ "metadata": {}, "outputs": [], "source": [ - "qreduced500 = qmugs500.reduce(NewSetName = 'qreduced500', cutoff = 5)" + "qreduced500 = qmugs500.reduce(NewSetName=\"qreduced500\", cutoff=5)" ] }, { @@ -78,7 +78,7 @@ "metadata": {}, "outputs": [], "source": [ - "q500solvents = qreduced500.add(otherset=solvents, NewSetName='q500solvents')" + "q500solvents = qreduced500.add(otherset=solvents, NewSetName=\"q500solvents\")" ] }, { @@ -96,8 +96,8 @@ "metadata": {}, "outputs": [], "source": [ - "qcorrected = q500solvents.reduce(NewSetName='qcorrected', otherset=corrected)\n", - "qleadlike = qcorrected.reduce(NewSetName='leadlike', otherset=leadlike)" + "qcorrected = q500solvents.reduce(NewSetName=\"qcorrected\", otherset=corrected)\n", + "qleadlike = qcorrected.reduce(NewSetName=\"leadlike\", otherset=leadlike)" ] }, { @@ -114,8 +114,8 @@ "metadata": {}, "outputs": [], "source": [ - "qleadreduced = qleadlike.reduce(NewSetName='qleadreduced')\n", - "final = qleadreduced.add(NewSetName='final', otherset=solvents)\n" + "qleadreduced = qleadlike.reduce(NewSetName=\"qleadreduced\")\n", + "final = qleadreduced.add(NewSetName=\"final\", otherset=solvents)" ] }, { @@ -132,7 +132,7 @@ "metadata": {}, "outputs": [], "source": [ - "final = MolMorganDataset('path/to/final.sdf')\n", + "final = MolMorganDataset(\"path/to/final.sdf\")\n", "final.weight_distribution()" ] }, @@ -168,23 +168,25 @@ "ID_tot = []\n", "set_ID = []\n", "wrongs = []\n", - "wrongchembls = ['CHEMBL3590587',\n", - " 'CHEMBL3590586',\n", - " 'CHEMBL3590584',\n", - " 'CHEMBL3590585',\n", - " 'CHEMBL3617051',\n", - " 'CHEMBL3752539'] #got these manually\n", + "wrongchembls = [\n", + " \"CHEMBL3590587\",\n", + " \"CHEMBL3590586\",\n", + " \"CHEMBL3590584\",\n", + " \"CHEMBL3590585\",\n", + " \"CHEMBL3617051\",\n", + " \"CHEMBL3752539\",\n", + "] # got these manually\n", "\n", "for mol in final._mols:\n", " if Chem.MolToSmiles(mol) not in smiles_tot:\n", " smiles_tot.append(Chem.MolToSmiles(mol))\n", - " else: \n", - " print('redundant molecule')\n", + " else:\n", + " print(\"redundant molecule\")\n", "\n", "for mol in qmugs500._mols:\n", " if Chem.MolToSmiles(mol) not in smiles_qmugs:\n", " smiles_qmugs.append(Chem.MolToSmiles(mol))\n", - " ID_qmugs.append(mol.GetProp('CHEMBL_ID'))\n", + " ID_qmugs.append(mol.GetProp(\"CHEMBL_ID\"))\n", "\n", "for mol in corrected._mols:\n", " if Chem.MolToSmiles(mol) not in smiles_corrected:\n", @@ -197,8 +199,8 @@ "for mol in leadlike._mols:\n", " if Chem.MolToSmiles(mol) not in smiles_leadlike:\n", " smiles_leadlike.append(Chem.MolToSmiles(mol))\n", - " if mol.HasProp('chembl_id'):\n", - " ID_leadlike.append(mol.GetProp('chembl_id'))\n", + " if mol.HasProp(\"chembl_id\"):\n", + " ID_leadlike.append(mol.GetProp(\"chembl_id\"))\n", " else:\n", " ID_leadlike.append(0)\n", "\n", @@ -216,25 +218,25 @@ " set_ID.append(3)\n", " ID_tot.append(ID_leadlike[smiles_leadlike.index(sm)])\n", " else:\n", - " print(smiles_tot.index(sm), ' is missing in others')\n", + " print(smiles_tot.index(sm), \" is missing in others\")\n", " set_ID.append(10)\n", - " ID_tot.append('missing')\n", + " ID_tot.append(\"missing\")\n", " wrongs.append(smiles_tot.index(sm))\n", - " \n", + "\n", "for i, ind in enumerate(wrongs):\n", " if set_ID[ind] == 10:\n", " set_ID[ind] = 3\n", " else:\n", - " print*('mistake', ind)\n", - " if ID_tot[ind] == 'missing':\n", - " ID_tot[ind]= wrongchembls[i]\n", + " print * (\"mistake\", ind)\n", + " if ID_tot[ind] == \"missing\":\n", + " ID_tot[ind] = wrongchembls[i]\n", " else:\n", - " print('mistake2', ind)\n", + " print(\"mistake2\", ind)\n", "\n", - "print(final._num_mol) #make sure that all list are same length and no molecules are missed\n", + "print(final._num_mol) # make sure that all list are same length and no molecules are missed\n", "print(len(smiles_tot))\n", "print(len(set_ID))\n", - "print(len(ID_tot))\n" + "print(len(ID_tot))" ] }, { @@ -243,9 +245,9 @@ "metadata": {}, "outputs": [], "source": [ - "printdata = {'Smiles': smiles_tot, \"Set_ID\": set_ID, \"CHEMBL_ID\": ID_tot}\n", + "printdata = {\"Smiles\": smiles_tot, \"Set_ID\": set_ID, \"CHEMBL_ID\": ID_tot}\n", "pls = pd.DataFrame(printdata)\n", - "print(len(smiles_tot),len(set_ID), len(ID_tot))" + "print(len(smiles_tot), len(set_ID), len(ID_tot))" ] }, { @@ -254,7 +256,7 @@ "metadata": {}, "outputs": [], "source": [ - "pls.to_csv('final_smiles.csv', index = True)" + "pls.to_csv(\"final_smiles.csv\", index=True)" ] }, { @@ -277,8 +279,9 @@ "source": [ "for ind in wrongs:\n", " set_ID.insert(ind, 3)\n", - " ID_tot.insert(ind, )\n", - " " + " ID_tot.insert(\n", + " ind,\n", + " )" ] }, { @@ -292,7 +295,7 @@ "for i, mol in enumerate(qleadlike._mols):\n", " if Chem.MolToSmiles(mol) in wrongsmiles:\n", " qleadlikewrongs.append(i)\n", - " chemblidwrongs.append(qleadlike._mols[i].GetProp('chembl_id'))" + " chemblidwrongs.append(qleadlike._mols[i].GetProp(\"chembl_id\"))" ] }, { @@ -303,7 +306,7 @@ "source": [ "Chem.MolToSmiles(leadlike._mols[leadlikewrongs[0]])\n", "leadlike._mols[leadlikewrongs[0]]\n", - "leadlike._mols[leadlikewrongs[5]].GetProp('chembl_id')" + "leadlike._mols[leadlikewrongs[5]].GetProp(\"chembl_id\")" ] }, { @@ -313,7 +316,8 @@ "outputs": [], "source": [ "from rdkit.Chem.Draw import IPythonConsole\n", - "IPythonConsole.molSize = 450,400\n", + "\n", + "IPythonConsole.molSize = 450, 400\n", "IPythonConsole.drawOptions.addAtomIndices = True\n", "m = Chem.Mol(leadlike._mols[leadlikewrongs[0]])\n", "m.RemoveAllConformers()\n", @@ -338,7 +342,7 @@ "leadlikewrongs = []\n", "for i, mol in enumerate(leadlike._mols):\n", " try:\n", - " if mol.GetProp('chembl_id') in chemblidwrongs:\n", + " if mol.GetProp(\"chembl_id\") in chemblidwrongs:\n", " leadlikewrongs.append(i)\n", " except:\n", " continue" diff --git a/serenityff/charge/examples/01_data_set_and_datapreperation.ipynb b/serenityff/charge/examples/01_data_set_and_datapreperation.ipynb index e69de29..1f1299e 100644 --- a/serenityff/charge/examples/01_data_set_and_datapreperation.ipynb +++ b/serenityff/charge/examples/01_data_set_and_datapreperation.ipynb @@ -0,0 +1,23 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "89d3f544", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/serenityff/charge/examples/02_train_ml_model.ipynb b/serenityff/charge/examples/02_train_ml_model.ipynb index 569ad50..e2a417c 100644 --- a/serenityff/charge/examples/02_train_ml_model.ipynb +++ b/serenityff/charge/examples/02_train_ml_model.ipynb @@ -23,16 +23,16 @@ "%autoreload 2\n", "\n", "from shutil import rmtree\n", + "\n", "import pandas as pd\n", "import torch\n", "from rdkit import Chem\n", "\n", "from serenityff.charge.gnn.training.trainer import (\n", - " Trainer,\n", " ChargeCorrectedNodeWiseAttentiveFP,\n", + " Trainer,\n", ")\n", "\n", - "\n", "sdf_file = \"../data/example.sdf\"\n", "pt_file = \"../data/example_graphs.pt\"\n", "state_dict_path = \"../data/example_state_dict.pt\"\n", diff --git a/serenityff/charge/examples/03_extract_attention_weights.ipynb b/serenityff/charge/examples/03_extract_attention_weights.ipynb index b445ad0..2534b09 100644 --- a/serenityff/charge/examples/03_extract_attention_weights.ipynb +++ b/serenityff/charge/examples/03_extract_attention_weights.ipynb @@ -7,11 +7,13 @@ "outputs": [], "source": [ "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", "from rdkit import Chem\n", "from rdkit.Chem.Draw import SimilarityMaps\n", + "\n", "from serenityff.charge.gnn.attention_extraction.extractor import Extractor\n", "\n", "sdf_file = \"../data/example.sdf\"\n", @@ -80,10 +82,7 @@ } ], "source": [ - "Extractor.run_extraction_local(sdf_file=sdf_file,\n", - " ml_model=model_path,\n", - " epochs=50,\n", - " output=extraction_path)" + "Extractor.run_extraction_local(sdf_file=sdf_file, ml_model=model_path, epochs=50, output=extraction_path)" ] }, { @@ -226,8 +225,8 @@ "atom_idx_to_draw = 2\n", "mol_idx_to_draw = 0\n", "attention_data_for_atom = eval(\n", - " df[(df[\"idx_in_mol\"] == atom_idx_to_draw)\n", - " & (df[\"mol_index\"] == mol_idx_to_draw)].iloc[0][\"node_attentions\"])\n", + " df[(df[\"idx_in_mol\"] == atom_idx_to_draw) & (df[\"mol_index\"] == mol_idx_to_draw)].iloc[0][\"node_attentions\"]\n", + ")\n", "mol = Chem.SDMolSupplier(sdf_file, removeHs=False)[mol_idx_to_draw]\n", "print(attention_data_for_atom)" ] @@ -341,7 +340,8 @@ " color=\"yellow\",\n", " fill=False,\n", " linewidth=3,\n", - " ))" + " )\n", + ")" ] }, { @@ -381,7 +381,8 @@ " color=\"yellow\",\n", " fill=False,\n", " linewidth=3,\n", - " ))\n", + " )\n", + " )\n", "plot" ] }, diff --git a/serenityff/charge/examples/04_build_a_decision_tree.ipynb b/serenityff/charge/examples/04_build_a_decision_tree.ipynb index 80bf638..e4d8816 100644 --- a/serenityff/charge/examples/04_build_a_decision_tree.ipynb +++ b/serenityff/charge/examples/04_build_a_decision_tree.ipynb @@ -8,17 +8,18 @@ "source": [ "import os\n", "from shutil import rmtree\n", + "\n", "from rdkit import Chem\n", "from rdkit.Chem.Draw import IPythonConsole\n", "\n", "IPythonConsole.ipython_useSVG = True\n", "IPythonConsole.drawOptions.addAtomIndices = True\n", - "from tqdm import tqdm\n", "import numpy as np\n", "import pandas as pd\n", + "from tqdm import tqdm\n", "\n", - "from serenityff.charge.tree.dash_tree import DASHTree\n", "from serenityff.charge.tree.atom_features import AtomFeatures\n", + "from serenityff.charge.tree.dash_tree import DASHTree\n", "from serenityff.charge.tree_develop.tree_constructor import Tree_constructor" ] }, @@ -1766,9 +1767,7 @@ "mol = Chem.SDMolSupplier(sdf_suply, removeHs=False)[0]\n", "[mol.ClearProp(prop) for prop in mol.GetPropNames()]\n", "atom_feture_key = AtomFeatures.atom_features_from_molecule(mol, 0)\n", - "print(\n", - " f\"Atom with idx:{atom_idx_in_mol} has AtomFeature: {AtomFeatures.lookup_int(atom_feture_key)}\"\n", - ")\n", + "print(f\"Atom with idx:{atom_idx_in_mol} has AtomFeature: {AtomFeatures.lookup_int(atom_feture_key)}\")\n", "mol" ] }, @@ -1786,9 +1785,7 @@ } ], "source": [ - "print(\n", - " f\"In the current version of serenityff there are {AtomFeatures.get_number_of_features()} atom features\"\n", - ")" + "print(f\"In the current version of serenityff there are {AtomFeatures.get_number_of_features()} atom features\")" ] }, { @@ -2363,9 +2360,7 @@ "source": [ "for branch in example_tree.tree_storage:\n", " if len(example_tree.tree_storage[branch]) > 1:\n", - " print(\n", - " f\"Branch {branch} has {len(example_tree.tree_storage[branch])} nodes\"\n", - " )" + " print(f\"Branch {branch} has {len(example_tree.tree_storage[branch])} nodes\")" ] }, { @@ -2501,14 +2496,8 @@ "for mol_idx in mol_idx_test:\n", " try:\n", " mol = Chem.SDMolSupplier(sdf_suply, removeHs=False)[mol_idx]\n", - " example_tree_charges.extend(\n", - " example_tree.get_molecules_partial_charges(mol)[\"charges\"]\n", - " )\n", - " example_ref_charges.extend(\n", - " tree_constructor.test_df[\n", - " tree_constructor.test_df.mol_index == mol_idx\n", - " ].truth.values\n", - " )\n", + " example_tree_charges.extend(example_tree.get_molecules_partial_charges(mol)[\"charges\"])\n", + " example_ref_charges.extend(tree_constructor.test_df[tree_constructor.test_df.mol_index == mol_idx].truth.values)\n", " except:\n", " print(f\"Failed mol with index {mol_idx}\")" ] @@ -3876,9 +3865,7 @@ "ax.plot([-1, 1], [-1, 1], color=\"grey\", linestyle=\"--\")\n", "ax.set_xlabel(\"Reference charges [e]\")\n", "ax.set_ylabel(\"Tree charges [e]\")\n", - "ax.set_title(\n", - " f\"Example tree charge correlation\\n RMSE: {np.sqrt(np.mean((df.tree-df.ref)**2)):.3f} e\"\n", - ")" + "ax.set_title(f\"Example tree charge correlation\\n RMSE: {np.sqrt(np.mean((df.tree - df.ref) ** 2)):.3f} e\")" ] }, { @@ -5298,9 +5285,7 @@ ], "source": [ "ax = pd.Series(charges_in_path).fillna(0).plot.line(label=\"Charge [e]\")\n", - "ax2 = pd.Series(counts_in_path).plot.line(\n", - " secondary_y=True, ax=ax, color=\"red\", label=\"Counts in node\"\n", - ")\n", + "ax2 = pd.Series(counts_in_path).plot.line(secondary_y=True, ax=ax, color=\"red\", label=\"Counts in node\")\n", "ax.set_xlabel(\"Tree depth along path\")\n", "ax.set_ylabel(\"Charge [e]\")\n", "ax.set_title(\"Example path in tree\")\n", @@ -6390,8 +6375,7 @@ } ], "source": [ - "ax = pd.Series(attention_in_path).cumsum().fillna(0).plot.line(\n", - " label=\"Attention\")\n", + "ax = pd.Series(attention_in_path).cumsum().fillna(0).plot.line(label=\"Attention\")\n", "ax.set_xlabel(\"Tree depth along path\")\n", "ax.set_ylabel(\"Attention\")\n", "ax.set_title(\"Cumulative attention along path in the example tree\")" diff --git a/serenityff/charge/examples/05_use_sff_as_off_plugin.ipynb b/serenityff/charge/examples/05_use_sff_as_off_plugin.ipynb index c71fe42..ac2da29 100644 --- a/serenityff/charge/examples/05_use_sff_as_off_plugin.ipynb +++ b/serenityff/charge/examples/05_use_sff_as_off_plugin.ipynb @@ -15,19 +15,19 @@ ], "source": [ "import os\n", + "\n", "import pandas as pd\n", "from rdkit import Chem\n", "from rdkit.Chem import AllChem\n", "from rdkit.Chem.Draw import IPythonConsole\n", - "IPythonConsole.ipython_useSVG=True\n", - "IPythonConsole.drawOptions.addAtomIndices=True\n", "\n", - "from matplotlib import pyplot as plt\n", + "IPythonConsole.ipython_useSVG = True\n", + "IPythonConsole.drawOptions.addAtomIndices = True\n", "\n", + "from matplotlib import pyplot as plt\n", "from openff.toolkit.topology import Molecule, Topology\n", "from openff.toolkit.typing.engines.smirnoff import ForceField\n", - "\n", - "from openmm import unit, app, LangevinIntegrator" + "from openmm import LangevinIntegrator, app, unit" ] }, { @@ -1046,7 +1046,9 @@ "metadata": {}, "outputs": [], "source": [ - "simulation.reporters.append(app.StateDataReporter(\"./example_openmm.csv\", 10, step=True, potentialEnergy=True, temperature=True))" + "simulation.reporters.append(\n", + " app.StateDataReporter(\"./example_openmm.csv\", 10, step=True, potentialEnergy=True, temperature=True)\n", + ")" ] }, { diff --git a/serenityff/charge/examples/06_pattern&tree_visualization.ipynb b/serenityff/charge/examples/06_pattern&tree_visualization.ipynb index 3c27d5c..8ce1a4c 100644 --- a/serenityff/charge/examples/06_pattern&tree_visualization.ipynb +++ b/serenityff/charge/examples/06_pattern&tree_visualization.ipynb @@ -7,17 +7,17 @@ "outputs": [], "source": [ "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", "from rdkit import Chem\n", - "from rdkit.Chem.Draw import SimilarityMaps\n", - "from rdkit.Chem import Draw\n", - "from rdkit.Chem import rdDepictor\n", - "from rdkit.Chem.Draw import IPythonConsole\n", + "from rdkit.Chem import Draw, rdDepictor\n", + "from rdkit.Chem.Draw import IPythonConsole, SimilarityMaps\n", + "\n", "IPythonConsole.ipython_useSVG = True\n", "IPythonConsole.ipython_showProperties = False\n", - "IPythonConsole.drawOptions.addAtomIndices=True\n", + "IPythonConsole.drawOptions.addAtomIndices = True\n", "\n", "sdf_file = \"../data/example.sdf\"\n", "model_path = \"../data/example_model.pt\"\n", @@ -1280,9 +1280,15 @@ } ], "source": [ - "plt_att_10 = SimilarityMaps.GetSimilarityMapFromWeights(mol, normalize_attention_between_0_and_1(att_atom_10), size=(100, 100))\n", - "plt_att_0 = SimilarityMaps.GetSimilarityMapFromWeights(mol, normalize_attention_between_0_and_1(att_atom_0), size=(100, 100))\n", - "plt_att_4 = SimilarityMaps.GetSimilarityMapFromWeights(mol, normalize_attention_between_0_and_1(att_atom_4), size=(100, 100))" + "plt_att_10 = SimilarityMaps.GetSimilarityMapFromWeights(\n", + " mol, normalize_attention_between_0_and_1(att_atom_10), size=(100, 100)\n", + ")\n", + "plt_att_0 = SimilarityMaps.GetSimilarityMapFromWeights(\n", + " mol, normalize_attention_between_0_and_1(att_atom_0), size=(100, 100)\n", + ")\n", + "plt_att_4 = SimilarityMaps.GetSimilarityMapFromWeights(\n", + " mol, normalize_attention_between_0_and_1(att_atom_4), size=(100, 100)\n", + ")" ] }, { @@ -1384,7 +1390,7 @@ } ], "source": [ - "df_mf2.sort_values(by=\"count\", ascending=False).plot.bar(figsize=(5, 2), rot=0, logy=True)\n" + "df_mf2.sort_values(by=\"count\", ascending=False).plot.bar(figsize=(5, 2), rot=0, logy=True)" ] }, { @@ -1408,7 +1414,7 @@ ], "source": [ "fig, ax = plt.subplots(1, 1, figsize=(15, 5))\n", - "df_mf2.sort_values(by=\"count\", ascending=False).plot.bar(ax=ax, rot=0, logy=True, width=1.1, color=\"C2\") \n", + "df_mf2.sort_values(by=\"count\", ascending=False).plot.bar(ax=ax, rot=0, logy=True, width=1.1, color=\"C2\")\n", "ax.set_xticklabels([])\n", "ax.set_xlabel(\"morgen fingerprint r=2\")\n", "ax.set_ylabel(\"count\")\n", diff --git a/serenityff/charge/examples/07_DASH-tree.ipynb b/serenityff/charge/examples/07_DASH-tree.ipynb index 3f8585b..47edcab 100644 --- a/serenityff/charge/examples/07_DASH-tree.ipynb +++ b/serenityff/charge/examples/07_DASH-tree.ipynb @@ -22,8 +22,9 @@ "source": [ "from rdkit import Chem\n", "from tqdm import tqdm\n", - "from serenityff.charge.tree.dash_tree import DASHTree\n", - "from serenityff.charge.tree.atom_features import AtomFeatures" + "\n", + "from serenityff.charge.tree.atom_features import AtomFeatures\n", + "from serenityff.charge.tree.dash_tree import DASHTree" ] }, { diff --git a/serenityff/charge/gnn/attention_extraction/__init__.py b/serenityff/charge/gnn/attention_extraction/__init__.py index 1aaa5c7..cc3d9ab 100644 --- a/serenityff/charge/gnn/attention_extraction/__init__.py +++ b/serenityff/charge/gnn/attention_extraction/__init__.py @@ -4,5 +4,6 @@ from .extractor import Extractor __all__ = [ - Extractor, + "Extractor", + "Explainer", ] diff --git a/serenityff/charge/gnn/attention_extraction/bash_templates/__init__.py b/serenityff/charge/gnn/attention_extraction/bash_templates/__init__.py index 6977e45..acc1596 100644 --- a/serenityff/charge/gnn/attention_extraction/bash_templates/__init__.py +++ b/serenityff/charge/gnn/attention_extraction/bash_templates/__init__.py @@ -2,3 +2,9 @@ from .cleaner import CLEANER_CONTENT from .worker import get_lsf_worker_content, get_slurm_worker_content + +__all__ = [ + "get_lsf_worker_content", + "get_slurm_worker_content", + "CLEANER_CONTENT", +] diff --git a/serenityff/charge/gnn/attention_extraction/bash_templates/worker.py b/serenityff/charge/gnn/attention_extraction/bash_templates/worker.py index 2ec8cbf..22fb0e3 100644 --- a/serenityff/charge/gnn/attention_extraction/bash_templates/worker.py +++ b/serenityff/charge/gnn/attention_extraction/bash_templates/worker.py @@ -4,7 +4,8 @@ #! /bin/bash python -c "from serenityff.charge.gnn.attention_extraction.extractor import Extractor; -Extractor._extract_hpc(model='$1', sdf_index=int(${{ varname }}), scratch='$TMPDIR', sdf_property_name='$3', no_charge_correction=bool(int('$4')))" +Extractor._extract_hpc(model='$1', sdf_index=int(${{ varname }}), scratch='$TMPDIR', sdf_property_name='$3', \ + no_charge_correction=bool(int('$4')))" mv ${TMPDIR}/${{{ varname }}}.csv ${2}/. """ diff --git a/serenityff/charge/gnn/attention_extraction/explainer.py b/serenityff/charge/gnn/attention_extraction/explainer.py index 39c58cb..dcc6924 100644 --- a/serenityff/charge/gnn/attention_extraction/explainer.py +++ b/serenityff/charge/gnn/attention_extraction/explainer.py @@ -4,10 +4,10 @@ from torch import Tensor from torch.nn import Module - from torch_geometric.explain.algorithm.gnn_explainer import ( GNNExplainer_ as GNNExplainer, ) + from serenityff.charge.gnn.utils import CustomData diff --git a/serenityff/charge/gnn/attention_extraction/extractor.py b/serenityff/charge/gnn/attention_extraction/extractor.py index b21b678..6fda19f 100644 --- a/serenityff/charge/gnn/attention_extraction/extractor.py +++ b/serenityff/charge/gnn/attention_extraction/extractor.py @@ -12,6 +12,12 @@ from rdkit import Chem from tqdm import tqdm +from serenityff.charge.gnn.attention_extraction import Explainer +from serenityff.charge.gnn.attention_extraction.bash_templates import ( + CLEANER_CONTENT, + get_lsf_worker_content, + get_slurm_worker_content, +) from serenityff.charge.gnn.utils import ( ChargeCorrectedNodeWiseAttentiveFP, NodeWiseAttentiveFP, @@ -19,12 +25,6 @@ ) from serenityff.charge.utils import command_to_shell_file from serenityff.charge.utils.exceptions import ExtractionError -from serenityff.charge.gnn.attention_extraction import Explainer -from serenityff.charge.gnn.attention_extraction.bash_templates import ( - CLEANER_CONTENT, - get_lsf_worker_content, - get_slurm_worker_content, -) class Extractor: @@ -158,7 +158,7 @@ def _explain_molecules_in_sdf( "truth", ], ) - out = output if output else f'{scratch}/{sdf_file.split(".")[0].split("/")[-1] + ".csv"}' + out = output if output else f"{scratch}/{sdf_file.split('.')[0].split('/')[-1] + '.csv'}" df.to_csv( path_or_buf=out, index=False, @@ -316,7 +316,7 @@ def _parse_filenames(args: Sequence[str]) -> argparse.Namespace: > -s, --sdffile: .SDF file containing a list of molecules\ you want a prediction and attention extaction for. > -p, --property: Name of the property in the sdf to explain. Defaults to 'MBIScharge'. - > --no-charge-correction: if flag is present, the model is not physics informed (i.e. not charge corrected.). + > --no-charge-correction: if flag is present, the model is not physics informed (not charge corrected.). Returns: argparse.Namespace: Namespace containing necessary strings. @@ -558,7 +558,8 @@ def run_extraction_slurm(args: Sequence[str]) -> None: slurm_command = ( f"sbatch -n 1 --cpus-per-task=1 --time=120:00:00 --job-name='clean_up' " f"--mem-per-cpu=1024 --output='logfiles/cleanup.out' --error='logfiles/cleanup.err' " - f"--open-mode=append --dependency=afterok:{id} --wrap='./cleaner.sh {num_files} {batch_size} {files.sdffile}'" + f"--open-mode=append --dependency=afterok:{id} --wrap='./cleaner.sh " + f"{num_files} {batch_size} {files.sdffile}'" ) os.system(slurm_command) command_to_shell_file(slurm_command, "run_cleanup.sh") diff --git a/serenityff/charge/gnn/training/trainer.py b/serenityff/charge/gnn/training/trainer.py index 5fe2a31..a65446a 100644 --- a/serenityff/charge/gnn/training/trainer.py +++ b/serenityff/charge/gnn/training/trainer.py @@ -7,7 +7,6 @@ from tqdm import tqdm - try: from typing import Literal except ImportError: @@ -18,8 +17,8 @@ from serenityff.charge.gnn.utils import ( ChargeCorrectedNodeWiseAttentiveFP, - NodeWiseAttentiveFP, CustomData, + NodeWiseAttentiveFP, get_graph_from_mol, mols_from_sdf, split_data_Kfold, @@ -241,13 +240,14 @@ def gen_graphs_from_sdf( verbose: bool = verbose, sdf_property_name: str = "MBIScharge", ) -> None: - """ - Creates pytorch geometric graphs using the custom featurizer for all molecules in a sdf file. 'MolFileAlias' in the sdf is taken - as the ground truth value, generate your input sdf file accordingly. + """Creates pytorch geometric graphs using the custom featurizer for all molecules in a sdf file. + + 'MolFileAlias' in the sdf is taken as the ground truth value, generate your input sdf file accordingly. Args: sdf_file (str): path to .sdf file holding the molecules. - allowable_set (Optional[List[int]], optional): Allowable atom types. Defaults to [ "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "H", ]. + allowable_set (Optional[List[int]], optional): Allowable atom types. Defaults to + [ "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "H", ]. """ mols = mols_from_sdf(sdf_file) self.data = [ @@ -285,11 +285,13 @@ def prepare_training_data( split: Optional[int] = 0, seed: Optional[int] = None, ) -> None: - """ - Splits training data into test data and eval data. At the moment, random, kfold and smiles split are implemented. + """Splits training data into test data and eval data. + + At the moment, random, kfold and smiles split are implemented. Args: - split_type (Optional[Literal["random", "kfold"]], optional): What split type you want. Defaults to "random". + split_type (Optional[Literal["random", "kfold"]], optional): What split type you want. + Defaults to "random". train_ratio (Optional[float], optional): ratio of train/eval in random split. Defaults to 0.8. n_splits (Optional[int], optional): number of splits in the kfold split. Defaults to 5. split (Optional[int], optional): which of the n_splits you want. Defaults to 0. diff --git a/serenityff/charge/gnn/utils/attentive_fp.py b/serenityff/charge/gnn/utils/attentive_fp.py index ae0f7db..d0a7253 100644 --- a/serenityff/charge/gnn/utils/attentive_fp.py +++ b/serenityff/charge/gnn/utils/attentive_fp.py @@ -1,13 +1,12 @@ # Copyright (C) 2023-2025 ETH Zurich, Marc Lehner and other DASH contributors. -from typing import Optional, Any import math +from typing import Any, Optional import torch import torch.nn.functional as F from torch import Tensor from torch.nn import GRUCell, Linear, Parameter - from torch_geometric.nn import GATConv, MessagePassing, global_add_pool from torch_geometric.typing import Adj, OptTensor from torch_geometric.utils import softmax @@ -77,7 +76,6 @@ def message( ptr: OptTensor, size_i: Optional[int], ) -> Tensor: - x_j = F.leaky_relu_(self.lin1(torch.cat([x_j, edge_attr], dim=-1))) alpha_j = (x_j * self.att_l).sum(dim=-1) alpha_i = (x_i * self.att_r).sum(dim=-1) diff --git a/serenityff/charge/gnn/utils/featurizer.py b/serenityff/charge/gnn/utils/featurizer.py index ff40a19..eeb630f 100644 --- a/serenityff/charge/gnn/utils/featurizer.py +++ b/serenityff/charge/gnn/utils/featurizer.py @@ -6,7 +6,6 @@ https://github.com/deepchem/deepchem/tree/master/deepchem/feat """ - import inspect import os from typing import Any, Iterable, List, Tuple, Union @@ -558,9 +557,9 @@ def _featurize(self, datapoint: Molecule, allowable_set: List[str], **kwargs) -> graph: CustomGraphData A molecule graph with some features. """ - assert ( - datapoint.GetNumAtoms() > 1 - ), "More than one atom should be present in the molecule for this featurizer to work." + assert datapoint.GetNumAtoms() > 1, ( + "More than one atom should be present in the molecule for this featurizer to work." + ) if "mol" in kwargs: datapoint = kwargs.get("mol") raise DeprecationWarning('Mol is being phased out as a parameter, please pass "datapoint" instead.') diff --git a/serenityff/charge/gnn/utils/model.py b/serenityff/charge/gnn/utils/model.py index 5e64c7c..970ad0c 100644 --- a/serenityff/charge/gnn/utils/model.py +++ b/serenityff/charge/gnn/utils/model.py @@ -5,6 +5,7 @@ import torch from torch import nn from torch.nn import functional as F + from serenityff.charge.gnn.utils.attentive_fp import AttentiveFP diff --git a/serenityff/charge/gnn/utils/split_utils.py b/serenityff/charge/gnn/utils/split_utils.py index fe507e8..e3cedba 100644 --- a/serenityff/charge/gnn/utils/split_utils.py +++ b/serenityff/charge/gnn/utils/split_utils.py @@ -4,8 +4,7 @@ from typing import List, Optional, Sequence, Tuple import torch -from sklearn.model_selection import KFold -from sklearn.model_selection import GroupShuffleSplit +from sklearn.model_selection import GroupShuffleSplit, KFold from .custom_data import CustomData diff --git a/serenityff/charge/tests/test_fundamental.py b/serenityff/charge/tests/test_fundamental.py index 4a81442..a7b373f 100644 --- a/serenityff/charge/tests/test_fundamental.py +++ b/serenityff/charge/tests/test_fundamental.py @@ -4,6 +4,7 @@ Test most fundamental functionality of package """ + import sys import pytest diff --git a/serenityff/charge/tests/test_gnn_extraction.py b/serenityff/charge/tests/test_gnn_extraction.py index 24438ef..9d1e831 100644 --- a/serenityff/charge/tests/test_gnn_extraction.py +++ b/serenityff/charge/tests/test_gnn_extraction.py @@ -10,14 +10,13 @@ import torch from rdkit import Chem +from serenityff.charge.gnn.attention_extraction import Explainer, Extractor +from serenityff.charge.gnn.attention_extraction.explainer import GNNExplainer from serenityff.charge.gnn.utils import ( ChargeCorrectedNodeWiseAttentiveFP, + CustomData, get_graph_from_mol, ) -from serenityff.charge.gnn.attention_extraction import Extractor -from serenityff.charge.gnn.attention_extraction import Explainer -from serenityff.charge.gnn.attention_extraction.explainer import GNNExplainer -from serenityff.charge.gnn.utils import CustomData from serenityff.charge.gnn.utils.model import NodeWiseAttentiveFP from serenityff.charge.gnn.utils.rdkit_helper import mols_from_sdf from serenityff.charge.utils import Molecule, command_to_shell_file diff --git a/serenityff/charge/tests/test_gnn_training.py b/serenityff/charge/tests/test_gnn_training.py index e25177f..3d6e822 100644 --- a/serenityff/charge/tests/test_gnn_training.py +++ b/serenityff/charge/tests/test_gnn_training.py @@ -7,10 +7,9 @@ from numpy import array_equal from rdkit import Chem from torch import device, load +from torch.cuda import is_available from torch.nn.functional import mse_loss from torch.optim import Adam -from torch.cuda import is_available - from serenityff.charge.gnn.training import Trainer from serenityff.charge.gnn.utils import ( @@ -174,7 +173,6 @@ def test_train_model(trainer, sdf_path) -> None: def test_prediction(trainer, graph, molecule) -> None: - a = trainer.predict(graph) b = trainer.predict(molecule) c = trainer.predict([graph]) @@ -199,5 +197,4 @@ def test_prediction(trainer, graph, molecule) -> None: def test_on_gpu(trainer) -> None: - assert trainer._on_gpu == is_available() diff --git a/serenityff/charge/tree/atom_features.py b/serenityff/charge/tree/atom_features.py index 39d5c29..c5d3750 100644 --- a/serenityff/charge/tree/atom_features.py +++ b/serenityff/charge/tree/atom_features.py @@ -1,6 +1,7 @@ # Copyright (C) 2022-2025 ETH Zurich, Marc Lehner and other DASH contributors. from typing import Any, Tuple + import numpy as np from rdkit import Chem @@ -34,9 +35,9 @@ class AtomFeatures: > IsConjugated (True or False) > Number of Hydrogens (e.g. 0, 1, 2, ...) (includeNeighbors=True) - It has a single atom form, or a form with the connection information like the relative index of the connected atom and the bond type. - a atom with connection information can be called for example with the function atom_features_from_molecule_w_connection_info() and a - example would be: + It has a single atom form, or a form with the connection information like the relative index of the connected atom + and the bond type. a atom with connection information can be called for example with the function + atom_features_from_molecule_w_connection_info() and a example would be: [(6, 3, 0, False, 0), 0, 1] Wich is a carbon atom with 3 bonds, 0 formal charge, not aromatic and 0 hydrogens connected to an atom with index 0 via a single bond. diff --git a/serenityff/charge/tree/atom_features_reduced.py b/serenityff/charge/tree/atom_features_reduced.py index 6d67a6d..4e1b871 100644 --- a/serenityff/charge/tree/atom_features_reduced.py +++ b/serenityff/charge/tree/atom_features_reduced.py @@ -1,12 +1,14 @@ # Copyright (C) 2024-2025 ETH Zurich, Marc Lehner and other DASH contributors. +from typing import Any, Tuple + +import numpy as np +from rdkit import Chem + from serenityff.charge.tree.atom_features import ( AtomFeatures, get_connection_info_bond_type, ) -from typing import Any, Tuple -import numpy as np -from rdkit import Chem from serenityff.charge.utils import Molecule diff --git a/serenityff/charge/tree/dash_tree.py b/serenityff/charge/tree/dash_tree.py index 353b088..2624f98 100644 --- a/serenityff/charge/tree/dash_tree.py +++ b/serenityff/charge/tree/dash_tree.py @@ -762,7 +762,8 @@ def get_molecular_dipole_moment( for i in range(1, len(cids)): mol.RemoveConformer(i) # center_of_mass = np.array(ComputeCentroid(mol.GetConformer())) - # dipole_vecs = [np.array(mol.GetConformer().GetAtomPosition(i)) - center_of_mass for i in range(mol.GetNumAtoms())] + # dipole_vecs = [np.array(mol.GetConformer().GetAtomPosition(i)) - + # center_of_mass for i in range(mol.GetNumAtoms())] # vec_sum = np.sum([chg * dipole_vec for chg, dipole_vec in zip(chgs, dipole_vecs)], axis=0) if sngl_cnf: vec_sum = np.sum( @@ -910,7 +911,8 @@ def explain_property( Returns ------- Image - Image showing the molecule with the matched atoms highlighted and the contribution of each atom added to the subgraph + Image showing the molecule with the matched atoms highlighted and the contribution of each atom added to the + subgraph """ node_path, match_indices = self.match_new_atom( atom, @@ -931,7 +933,8 @@ def explain_property( property_name = "Partial charge" if prop_unit is None: prop_unit = "e" - plot_title = f"Atom: 0 ({atom} in mol). \nSum of all contributions: {property_name} = {prop_per_node[-1]: .2f} {prop_unit}" + plot_title = f"""Atom: 0 ({atom} in mol). + Sum of all contributions: {property_name} = {prop_per_node[-1]: .2f} {prop_unit}""" return draw_mol_with_highlights_in_order( mol=mol, highlight_atoms=match_indices, diff --git a/serenityff/charge/tree/node.py b/serenityff/charge/tree/node.py index 063fd19..ea0c5c1 100644 --- a/serenityff/charge/tree/node.py +++ b/serenityff/charge/tree/node.py @@ -2,11 +2,11 @@ from __future__ import annotations +from copy import copy, deepcopy from typing import Dict, List import numpy as np import pandas as pd -from copy import deepcopy, copy from serenityff.charge.tree.atom_features import AtomFeatures @@ -115,8 +115,9 @@ def add_node(self, node): self.add_child(child) def _update_statistics(self, other): - """ - helper function to merge two nodes. Updates the statistics of the current node with the statistics of the other node + """Helper function to merge two nodes. + + Updates the statistics of the current node with the statistics of the other node Parameters ---------- @@ -131,9 +132,9 @@ def _update_statistics(self, other): # had to google that one to figure out how to do this: # https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups if self.count >= 3 and other.count >= 3: - std_term1 = ( - (self.stdDeviation**2 * (self.count - 1)) + (other.stdDeviation**2 * (other.count - 1)) - ) / (self.count + other.count - 1) + std_term1 = ((self.stdDeviation**2 * (self.count - 1)) + (other.stdDeviation**2 * (other.count - 1))) / ( + self.count + other.count - 1 + ) std_term2 = (self.count * other.count * (self.result - other.result) ** 2) / ( (self.count + other.count) * (self.count + other.count - 1) ) @@ -189,7 +190,7 @@ def prune(self, threshold=0.001): # child.prune(threshold) if hasattr(self, "stdDeviation") and self.stdDeviation != np.nan: - adjusted_threshold = threshold * ((self.level / 8)) + adjusted_threshold = threshold * (self.level / 8) # adjusted_threshold = threshold if self.stdDeviation < adjusted_threshold: for child in self.children: diff --git a/serenityff/charge/tree/retrieve_data.py b/serenityff/charge/tree/retrieve_data.py index b9b1da0..5fa7d4d 100644 --- a/serenityff/charge/tree/retrieve_data.py +++ b/serenityff/charge/tree/retrieve_data.py @@ -1,6 +1,7 @@ # Copyright (C) 2024-2025 ETH Zurich, Niels Maeder and other DASH contributors. """Functionality to obtain the DASH properties data from ETH research archive.""" + import zipfile from enum import Enum, auto from pathlib import Path diff --git a/serenityff/charge/tree/tree_factory.py b/serenityff/charge/tree/tree_factory.py index 6650474..6dfcd95 100644 --- a/serenityff/charge/tree/tree_factory.py +++ b/serenityff/charge/tree/tree_factory.py @@ -1,6 +1,7 @@ # Copyright (C) 2024-2025 ETH Zurich, Niels Maeder and other DASH contributors. """Factory for dash Trees with different properties loaded.""" + from serenityff.charge.data import dash_props_tree_path from serenityff.charge.tree.dash_tree import DASHTree, TreeType diff --git a/serenityff/charge/tree/tree_utils.py b/serenityff/charge/tree/tree_utils.py index c15e7d1..b83f442 100644 --- a/serenityff/charge/tree/tree_utils.py +++ b/serenityff/charge/tree/tree_utils.py @@ -6,12 +6,9 @@ import pandas as pd from numba import njit - from serenityff.charge.tree.atom_features import AtomFeatures - -from serenityff.charge.tree_develop.develop_node import DevelopNode - from serenityff.charge.tree.dash_tree import DASHTree +from serenityff.charge.tree_develop.develop_node import DevelopNode def get_possible_atom_features(mol, connected_atoms): diff --git a/serenityff/charge/tree_develop/develop_node.py b/serenityff/charge/tree_develop/develop_node.py index c30128e..978efe2 100644 --- a/serenityff/charge/tree_develop/develop_node.py +++ b/serenityff/charge/tree_develop/develop_node.py @@ -35,8 +35,16 @@ def __str__(self) -> str: return f"node --- lvl: {self.level}, Num=1" else: if self.truth_values is None: - return f"node --- lvl: {self.level}, empty node, fp={AtomFeatures.lookup_int(self.atom_features[0])} ({self.atom_features[1]}, {self.atom_features[2]})" - return f"node --- lvl: {self.level}, Num={str(len(self.truth_values))}, Mean={float(self.average):.4f}, std={np.std(self.truth_values):.4f}, fp={AtomFeatures.lookup_int(self.atom_features[0])} ({self.atom_features[1]}, {self.atom_features[2]})" + return ( + f"node --- lvl: {self.level}, empty node, fp={AtomFeatures.lookup_int(self.atom_features[0])} " + f"({self.atom_features[1]}, {self.atom_features[2]})" + ) + + return ( + f"node --- lvl: {self.level}, Num={str(len(self.truth_values))}, Mean={float(self.average):.4f}, " + f"std={np.std(self.truth_values):.4f}, fp={AtomFeatures.lookup_int(self.atom_features[0])} " + f"({self.atom_features[1]}, {self.atom_features[2]})" + ) def __hash__(self) -> int: return hash(str(self)) diff --git a/serenityff/charge/tree_develop/tree_constructor.py b/serenityff/charge/tree_develop/tree_constructor.py index 1b70710..a4332d3 100644 --- a/serenityff/charge/tree_develop/tree_constructor.py +++ b/serenityff/charge/tree_develop/tree_constructor.py @@ -1,18 +1,18 @@ # Copyright (C) 2022-2025 ETH Zurich, Marc Lehner and other DASH contributors. import datetime +import logging +import os import pickle import random -import os -import logging import time +from collections import defaultdict from typing import NoReturn import numpy as np import pandas as pd from rdkit import Chem from tqdm import tqdm -from collections import defaultdict from serenityff.charge.tree.atom_features import ( AtomFeatures, @@ -54,7 +54,8 @@ def __init__( II) sanitize the dataframe III) split the dataframe into a train and test set (randomly or by indices provided) IV) create all adjacency matrices and atom features for all molecules - V) prepare everything for the tree building (seperate functions "create_tree_level_0" and "build_tree" are needed for the actual tree building) + V) prepare everything for the tree building (seperate functions "create_tree_level_0" and "build_tree" are + needed for the actual tree building) Parameters ---------- @@ -262,7 +263,8 @@ def _raise_index_missmatch_error( self, mol_index, number_of_atoms_in_mol_df, number_of_atoms_in_mol_sdf ) -> NoReturn: print( - f"Molecule {mol_index} has {number_of_atoms_in_mol_df} atoms in df and {number_of_atoms_in_mol_sdf} atoms in sdf" + f"Molecule {mol_index} has {number_of_atoms_in_mol_df} atoms in df " + f"and {number_of_atoms_in_mol_sdf} atoms in sdf" ) print(f"shifted mol has {self.sdf_suplier[mol_index + 1].GetNumAtoms()} atoms") print("--------------------------------------------------") @@ -289,7 +291,8 @@ def _check_charge_sanity(self) -> None: self.original_df.drop(indices_to_drop, inplace=True) if self.verbose: print( - f"Number of wrong charged mols: {len(self.wrong_charged_mols_list)} of {len(self.original_df.mol_index.unique())} mols" + "Number of wrong charged mols: " + f"{len(self.wrong_charged_mols_list)} of {len(self.original_df.mol_index.unique())} mols" ) def _check_charges(self, element, charge, indices_to_drop, df_with_mol_index, mol_index) -> None: diff --git a/serenityff/charge/tree_develop/tree_constructor_parallel_worker.py b/serenityff/charge/tree_develop/tree_constructor_parallel_worker.py index 086291d..31d0647 100644 --- a/serenityff/charge/tree_develop/tree_constructor_parallel_worker.py +++ b/serenityff/charge/tree_develop/tree_constructor_parallel_worker.py @@ -2,6 +2,7 @@ import traceback from typing import List + import numpy as np import pandas as pd diff --git a/serenityff/charge/tree_develop/tree_constructor_singleJB_worker.py b/serenityff/charge/tree_develop/tree_constructor_singleJB_worker.py index 3c97636..c0b3bd7 100644 --- a/serenityff/charge/tree_develop/tree_constructor_singleJB_worker.py +++ b/serenityff/charge/tree_develop/tree_constructor_singleJB_worker.py @@ -1,11 +1,10 @@ # Copyright (C) 2023-2025 ETH Zurich, Marc Lehner and other DASH contributors. -import pickle +import argparse import os +import pickle from typing import Sequence -import argparse -from serenityff.charge.utils import command_to_shell_file from serenityff.charge.tree.tree_utils import ( get_DASH_tree_from_DEV_tree, ) @@ -13,6 +12,7 @@ from serenityff.charge.tree_develop.tree_constructor_parallel_worker import ( Tree_constructor_parallel_worker, ) +from serenityff.charge.utils import command_to_shell_file class Tree_constructor_singleJB_worker: @@ -47,11 +47,23 @@ def run_singleJB(Tree_constructor_parallel_worker_path: str, AF_idx: int) -> Non out_folder = "tree_out" if not os.path.exists(out_folder): os.makedirs(out_folder) - command = f"#SBATCH -n 1\n#SBATCH --cpus-per-task=64\n#SBATCH --time=120:00:00\n#SBATCH --job-name='t_{AF_idx}'\n#SBATCH --nodes=1\n#SBATCH --mem-per-cpu=8000\n#SBATCH --tmp=50000\n#SBATCH --output='t_{AF_idx}.out'\n#SBATCH --error='t_{AF_idx}.err'\n#SBATCH --open-mode=append" + command = f"""#SBATCH -n 1 + #SBATCH --cpus-per-task=64 + #SBATCH --time=120:00:00 + #SBATCH --job-name='t_{AF_idx}' + #SBATCH --nodes=1 + #SBATCH --mem-per-cpu=8000 + #SBATCH --tmp=50000 + #SBATCH --output='t_{AF_idx}.out' + #SBATCH --error='t_{AF_idx}.err' + #SBATCH --open-mode=append""" # copy all the files to the $TMPDIR directory command += f"cp {Tree_constructor_parallel_worker_path} $TMPDIR/{local_tree_constructor}" command += "cd $TMPDIR" - command += f"python serenityff/charge/tree_develop/tree_constructor_singleJB_worker.py -p {local_tree_constructor} -a {AF_idx}" + command += ( + "python serenityff/charge/tree_develop/tree_constructor_singleJB_worker.py " + r"-p {local_tree_constructor} -a {AF_idx}" + ) command += f"cp {AF_idx}.pkl {sub_folder}/tree_out/" command_to_shell_file(command, f"singleJB_{AF_idx}.sh") os.system(f"sbatch < singleJB_{AF_idx}.sh") diff --git a/serenityff/charge/utils/serenityff_charge_handler.py b/serenityff/charge/utils/serenityff_charge_handler.py index 77c7e2b..520a995 100644 --- a/serenityff/charge/utils/serenityff_charge_handler.py +++ b/serenityff/charge/utils/serenityff_charge_handler.py @@ -1,8 +1,8 @@ # Copyright (C) 2022-2025 ETH Zurich, Niels Maeder and other DASH contributors. from typing import List -import numpy as np +import numpy as np from openff.toolkit.typing.engines.smirnoff import ( ElectrostaticsHandler, LibraryChargeHandler, @@ -17,7 +17,6 @@ class SerenityFFChargeHandler(_NonbondedHandler, ToolkitWrapper): - _TAGNAME = "SerenityFFCharge" _DEPENDENCIES = [ElectrostaticsHandler, LibraryChargeHandler, vdWHandler] _KWARGS = ["toolkit_registry"] @@ -73,7 +72,6 @@ def create_force(self, system, topology, **kwargs) -> None: force = super().create_force(system, topology, **kwargs) for reference_molecule in topology.reference_molecules: - for topology_molecule in topology._reference_molecule_to_topology_molecules[reference_molecule]: partial_charges = self.assign_partial_charges(topology_molecule) diff --git a/tests/_utils.py b/tests/_utils.py index 0650a7a..faa0f92 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -1,6 +1,7 @@ # Copyright (C) 2024-2025 ETH Zurich, Niels Maeder and other DASH contributors. """Utilitary functions for running the pytests.""" + import json import os from pathlib import Path diff --git a/tests/serenityff/charge/utils/test_serenityff_charge_handler.py b/tests/serenityff/charge/utils/test_serenityff_charge_handler.py index d5d3526..f998418 100644 --- a/tests/serenityff/charge/utils/test_serenityff_charge_handler.py +++ b/tests/serenityff/charge/utils/test_serenityff_charge_handler.py @@ -1,6 +1,7 @@ # Copyright (C) 2024-2025 ETH Zurich, Niels Maeder and other DASH contributors. """Test serenityff.charge.utils.serenityff_charge_handler.py.""" + from pathlib import Path import pytest