Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions mldft/utils/molecules.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from pyscf.lib.chkfile import load
from rdkit import Chem

from mldft.utils.conversions import pyscf_to_rdkit

# To avoid circular imports and make flake8 happy
if TYPE_CHECKING:
from mldft.ml.data.components.of_data import OFData
Expand Down Expand Up @@ -515,6 +517,17 @@ def geometry_to_string(mol: gto.Mole, unit: str = "Angstrom"):
)


def get_mol_view_link(mol: gto.Mole) -> str:
    """Get the MolView link for a molecule.

    The molecule is converted to an RDKit molecule, rendered as a SMILES string and
    embedded (percent-encoded) as the ``smiles`` query parameter of a molview.org URL.

    Args:
        mol: The molecule to build the link for.

    Returns:
        The MolView URL, or a human-readable error message if ``pyscf_to_rdkit``
        raises an ``AtomValenceException``.
    """
    # Function-scope import so the block does not rely on a new file-level import.
    from urllib.parse import quote

    try:
        chem_mol = pyscf_to_rdkit(mol)
    except Chem.rdchem.AtomValenceException:
        return "Can't compute smiles, rdkit failed to convert the molecule."

    smiles = Chem.MolToSmiles(chem_mol)

    # Round-trip through RDKit to simplify the SMILES string. MolFromSmiles returns
    # None when it cannot parse its input; keep the unsimplified string in that case
    # instead of crashing in MolToSmiles.
    reparsed_mol = Chem.MolFromSmiles(smiles)
    if reparsed_mol is not None:
        smiles = Chem.MolToSmiles(reparsed_mol)

    # SMILES may contain characters that are special inside a URL query: '#' (triple
    # bond) would start the fragment, '+' (charge) is commonly decoded as a space,
    # and '=', '/', '\\', '@', '[', ']' are reserved. Percent-encode all of them.
    encoded_smiles = quote(smiles, safe="")
    return f"https://molview.org/?smiles={encoded_smiles}"


def check_atom_types(mol: gto.Mole, atom_types: np.ndarray) -> None:
"""Check if all atoms in the molecule are of a certain type.

Expand Down
11 changes: 11 additions & 0 deletions mldft/utils/omegaconf_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,17 @@ def parse_reps_import(rep):
return parse_reps(rep)


def to_no_basis_transforms_dataset_statistics(
    dataset_statistics_path: str, transformation_name: str
) -> str:
    """Rewrite a dataset statistics path so that it points to the no-basis-transforms variant.

    Every occurrence of ``transformation_name`` in ``dataset_statistics_path`` is
    substituted by the literal ``"no_basis_transforms"``. This is a plain string
    replacement: a path that does not contain ``transformation_name`` is returned
    unchanged.

    Args:
        dataset_statistics_path: The dataset statistics path to convert.
        transformation_name: The substring to be replaced in the path.

    Returns:
        The path with all occurrences of ``transformation_name`` replaced.
    """
    replacement = "no_basis_transforms"
    converted_path = dataset_statistics_path.replace(transformation_name, replacement)
    return converted_path


# Register the path conversion under the resolver name
# "to_no_basis_transforms_dataset_statistics" with OmegaConf.
OmegaConf.register_new_resolver(
"to_no_basis_transforms_dataset_statistics", to_no_basis_transforms_dataset_statistics
)

# getting dim of tensor rep:
# "get_rep_dim" parses the given rep and returns its `.dim`;
# "get_tensorrep_dim" does the same after prefixing the rep string with "t".
OmegaConf.register_new_resolver("get_rep_dim", lambda rep: parse_reps_import(rep).dim)
OmegaConf.register_new_resolver("get_tensorrep_dim", lambda rep: parse_reps_import(f"t{rep}").dim)
Expand Down
15 changes: 8 additions & 7 deletions mldft/utils/visualize_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def plot_orbital(
plotter: pv.Plotter = None,
figsize: tuple[int, int] = None,
title: str = None,
**plot_molecule_kwargs,
) -> pv.Plotter:
"""Plot an electron orbital using pyvista. By default, the orbital is plotted as a volume.

Expand Down Expand Up @@ -430,7 +431,7 @@ def plot_orbital(

if plot_molecule:
mol = cube.mol
pl.add_mesh(**get_sticks_mesh_dict(mol))
pl.add_mesh(**get_sticks_mesh_dict(mol, **plot_molecule_kwargs))

if mode in ["isosurface", "nested_isosurfaces"]:
if isinstance(isosurface_quantile, float):
Expand All @@ -456,12 +457,12 @@ def plot_orbital(
iso_mesh["quantile"] = np.zeros(iso_mesh.n_points)
iso_mesh["opacity"] = np.zeros(iso_mesh.n_points)

for quantile, isosurface_value in zip(quantiles, isosurface_values):
mask = iso_mesh["data"] == isosurface_value

# squaring looks good with 'seismic' colormap
iso_mesh["quantile"][mask] = quantile
iso_mesh["opacity"][mask] = 1 - np.abs(quantile)
# for every point in the mesh, find the closest isosurface value
# and assign the corresponding quantile
abs_value_differences = np.abs(iso_mesh["data"][:, None] - isosurface_values[None, :])
closest_isosurface_value_indices = np.argmin(abs_value_differences, axis=1)
iso_mesh["quantile"] = quantiles[closest_isosurface_value_indices]
iso_mesh["opacity"] = 1 - np.abs(iso_mesh["quantile"])

if mode == "nested_isosurfaces":
# add a set of nested, transparent isosurfaces
Expand Down
133 changes: 133 additions & 0 deletions notebooks/tutorial/tutorial_0_overview.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9f1be328",
"metadata": {},
"source": [
"# Tutorial 0: Overview of the OF-DFT codebase of the SCIAI-Lab\n",
"\n",
"### What our codebase is all about \n",
"This tutorial offers an introduction to our large and, in some places, quite complicated codebase. The goal of the codebase is to enable machine-learned orbital-free density functional theory (OF-DFT). The idea is summarized well in the following figure (see https://pubs.acs.org/doi/full/10.1021/jacs.5c06219 for details):\n",
"\n",
"<img src=\"tutorial_fig1.png\" alt=\"Figure from https://pubs.acs.org/doi/full/10.1021/jacs.5c06219\" style=\"width: 90%\">\n",
"\n",
"In very simple words: we take the molecules from different public datasets (e.g. QM9 or QMugs) and compute the energy that these molecules have for different electron densities (= distributions of the electrons around the atomic nuclei). For each molecule, the densities are described by a linear superposition of atom-centered basis functions (functions of 3D space localized around the different atoms in the molecule). \n",
"\n",
"From the computation of the energies we also get gradients that tell us how to change the density in order to decrease the energy. Especially for large molecules, these computations, traditionally done with so-called Kohn-Sham DFT (KS-DFT), are very expensive. Therefore, our goal is to train a neural network that can reproduce the energies and gradients for the electron densities in our datasets (trained on KS-DFT data) and hopefully generalizes to larger molecules.\n",
"\n",
"Our trained neural network can be used to follow the gradients to lower and lower energies to finally obtain a good estimate of the ground state electron density. We call this process density optimization. \n",
"\n",
"For a broader overview of what OF-DFT is and what we as a group are doing, consider watching the following lecture videos: [part 1](https://www.youtube.com/watch?v=CoZUTMjU8C8) and [part 2](https://www.youtube.com/watch?v=iyx1C4vaP7k). \n",
"\n",
"### What this tutorial covers\n",
"This tutorial covers a broad range of topics from the loading and handling of the molecule dataset (which contains samples that combine an electron density with a target energy and target energy gradient) over visualizing molecules and electron densities all the way to understanding how our machine learning model is trained and can be used for density optimization. \n",
"\n",
"In summary, this tutorial will guide you through the following topics:\n",
"1) [**datamodule**](./tutorial_1_datamodule.ipynb): Understanding the following classes that handle the loading and processing of our data: OFDataset, OFData, OFBatch, OFLoader, OFDataModule, BasisInfo\n",
"\n",
"2) [**visualization**](./tutorial_2_visualization.ipynb): Demonstration of how to visualize molecules and electron densities in 3D via molview.org and via Pyvista\n",
"\n",
"3) [**transforms**](./tutorial_3_transforms.ipynb): Demystifying the MasterTransformation class that handles the transformation of data samples before they can be passed to the model for training (including a visualization of the basis transforms). Keywords: Global symmetric natrep, local frames, gradient projection.\n",
"\n",
"4) [**hydra & omegaconf**](./tutorial_4_hydra_omegaconf.ipynb): Understanding how configs are managed with hydra and omegaconf, including hydra overrides, omegaconf syntax and omegaconf resolvers.\n",
"\n",
"5) [**mldftlitmodule**](./tutorial_5_mldftlitmodel.ipynb): Understanding how to train a model based on our PyTorch Lightning class MLDFTLitModule, including a closer look at forward, training_step and backpropagate, and at the usage of dataset statistics in our model. \n",
"\n",
"6) [**density optimization**](./tutorial_6_density_optimization.ipynb): Demonstration of how to load a trained model from a checkpoint and perform density optimization, including important plots to evaluate the density optimization process. "
]
},
{
"cell_type": "markdown",
"id": "eb36cabb",
"metadata": {},
"source": [
"Before starting the tutorial, please take a look at the [README](../../README.md) to set up your virtual environment."
]
},
{
"cell_type": "markdown",
"id": "252daebd",
"metadata": {},
"source": [
"Please execute the following cell to download two small sample datasets from our huggingface model repository (https://huggingface.co/sciai-lab/structures25/tree/main). One dataset contains QM9 molecules and the other contains the larger QMugs molecules. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "217bcba1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# download a small dataset from huggingface that contains QM9 and QMugs data\n",
"# https://huggingface.co/docs/datasets/cache#cache-directory\n",
"# The default cache directory is `~/.cache/huggingface/datasets`\n",
"# You can change it by setting this variable to any path you like\n",
"CACHE_DIR = None # e.g. change it to \"./hf_cache\"\n",
"\n",
"\n",
"# https://huggingface.co/sciai-lab/structures25/tree/main\n",
"print(\"Downloading minimal dataset for QM9 and QMugs from huggingface...\")\n",
"os.environ[\n",
" \"HF_HUB_DISABLE_PROGRESS_BARS\"\n",
"] = \"1\" # to avoid problems with the progress bar in some environments\n",
"from huggingface_hub import snapshot_download\n",
"\n",
"data_path = snapshot_download(\n",
" repo_id=\"sciai-lab/minimal_data_QM9_QMugs\", cache_dir=CACHE_DIR, repo_type=\"dataset\"\n",
")\n",
"\n",
"print(f\"Successfully downloaded data to the following path {data_path}.\")"
]
},
{
"cell_type": "markdown",
"id": "19cd7ec3",
"metadata": {},
"source": [
"If you want to train models properly beyond this tutorial (on the IWR servers), please set the environment variables `DFT_DATA` and `DFT_MODELS` in your `.bashrc` or `.zshrc` file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c62ac86c",
"metadata": {},
"outputs": [],
"source": [
"# overall hint: to be able to go through this tutorial and execute all cells,\n",
"# you should be able to access our data and models folder\n",
"# and should have set the two environment variables DFT_DATA and DFT_MODELS:\n",
"\n",
"print(\"DFT_DATA:\", os.getenv(\"DFT_DATA\")) # set it to /export/scratch/ialgroup/dft_data\n",
"print(\n",
" \"DFT_MODELS:\", os.getenv(\"DFT_MODELS\")\n",
") # set this one to where you want to save your trained models"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mldft",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading