diff --git a/mldft/utils/molecules.py b/mldft/utils/molecules.py index 09083b5..489ab1e 100644 --- a/mldft/utils/molecules.py +++ b/mldft/utils/molecules.py @@ -17,6 +17,8 @@ from pyscf.lib.chkfile import load from rdkit import Chem +from mldft.utils.conversions import pyscf_to_rdkit + # To avoid circular imports and make flake8 happy if TYPE_CHECKING: from mldft.ml.data.components.of_data import OFData @@ -515,6 +517,17 @@ def geometry_to_string(mol: gto.Mole, unit: str = "Angstrom"): ) +def get_mol_view_link(mol: gto.Mole) -> str: + """Get the MolView link for a molecule.""" + try: + chem_mol = pyscf_to_rdkit(mol) + except Chem.rdchem.AtomValenceException: + return "Can't compute smiles, rdkit failed to convert the molecule." + smiles = Chem.MolToSmiles(chem_mol) + smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) # simplify the smiles string + return f"https://molview.org/?smiles={smiles}" + + def check_atom_types(mol: gto.Mole, atom_types: np.ndarray) -> None: """Check if all atoms in the molecule are of a certain type. diff --git a/mldft/utils/omegaconf_resolvers.py b/mldft/utils/omegaconf_resolvers.py index 5970fb7..1435cde 100644 --- a/mldft/utils/omegaconf_resolvers.py +++ b/mldft/utils/omegaconf_resolvers.py @@ -261,6 +261,17 @@ def parse_reps_import(rep): return parse_reps(rep) +def to_no_basis_transforms_dataset_statistics( + dataset_statistics_path: str, transformation_name: str +) -> str: + """Converts a dataset statistics path to no basis transforms by a simple replacement.""" + return dataset_statistics_path.replace(transformation_name, "no_basis_transforms") + + +OmegaConf.register_new_resolver( + "to_no_basis_transforms_dataset_statistics", to_no_basis_transforms_dataset_statistics +) + # getting dim of tensor rep: OmegaConf.register_new_resolver("get_rep_dim", lambda rep: parse_reps_import(rep).dim) OmegaConf.register_new_resolver("get_tensorrep_dim", lambda rep: parse_reps_import(f"t{rep}").dim) diff --git a/mldft/utils/visualize_3d.py b/mldft/utils/visualize_3d.py index d2210db..296e006 100644 --- a/mldft/utils/visualize_3d.py +++ b/mldft/utils/visualize_3d.py @@ -366,6 +366,7 @@ def plot_orbital( plotter: pv.Plotter = None, figsize: tuple[int, int] = None, title: str = None, + **plot_molecule_kwargs, ) -> pv.Plotter: """Plot an electron orbital using pyvista. By default, the orbital is plotted as a volume. @@ -430,7 +431,7 @@ def plot_orbital( if plot_molecule: mol = cube.mol - pl.add_mesh(**get_sticks_mesh_dict(mol)) + pl.add_mesh(**get_sticks_mesh_dict(mol, **plot_molecule_kwargs)) if mode in ["isosurface", "nested_isosurfaces"]: if isinstance(isosurface_quantile, float): @@ -456,12 +457,12 @@ def plot_orbital( iso_mesh["quantile"] = np.zeros(iso_mesh.n_points) iso_mesh["opacity"] = np.zeros(iso_mesh.n_points) - for quantile, isosurface_value in zip(quantiles, isosurface_values): - mask = iso_mesh["data"] == isosurface_value - - # squaring looks good with 'seismic' colormap - iso_mesh["quantile"][mask] = quantile - iso_mesh["opacity"][mask] = 1 - np.abs(quantile) + # for every point in the mesh, find the closest isosurface value + # and assign the corresponding quantile + abs_value_differences = np.abs(iso_mesh["data"][:, None] - isosurface_values[None, :]) + closest_isosurface_value_indices = np.argmin(abs_value_differences, axis=1) + iso_mesh["quantile"] = quantiles[closest_isosurface_value_indices] + iso_mesh["opacity"] = 1 - np.abs(iso_mesh["quantile"]) if mode == "nested_isosurfaces": # add a set of nested, transparent isosurfaces diff --git a/notebooks/tutorial/tutorial_0_overview.ipynb b/notebooks/tutorial/tutorial_0_overview.ipynb new file mode 100644 index 0000000..93f62d4 --- /dev/null +++ b/notebooks/tutorial/tutorial_0_overview.ipynb @@ -0,0 +1,133 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9f1be328", + "metadata": {}, + "source": [ + "# Tutorial 0: Overview of the OF-DFT codebase of the SCIAI-Lab\n", + "\n", + "### What our codebase is all about \n", + "This tutorial is designed to offer an introduction to our large and in some places quite complicated codebase. Our codebase serves the goal to enable machine-learned orbital free density function theory (OF-DFT). The idea is summarized well in the following figure (see https://pubs.acs.org/doi/full/10.1021/jacs.5c06219 for details):\n", + "\n", + "\"Figure\n", + "\n", + "In very simple words: we take the molecules from different public dataset (e.g. QM9 or QMugs) and compute the energy that these molecule have for different electron densities (= disribution of the electrons around the atom nuclei). For each molecule, the densities are described by a linear superposition of atom-cendered basis functions (functions of 3D space localized around the different atoms in the molecule). \n", + "\n", + "From the computation of the energies we also get gradients that tell us how to change the density in order to decrease the energy. Especially on large molecules, these computations, traditionally done with the so called Kohn-Sham DFT, are very expensive. Therefore, our goal is to train a neural network that can reproduce the energie and gradients for the electron densities in our datasets (trained on KS-DFT data) and hopefully generalizes to larger molecules.\n", + "\n", + "Our trained neural network can be used to follow the gradients to lower and lower energies to finally obtain a good estimate of the ground state electron density. We call this process density optimization. \n", + "\n", + "For a broader overview into what OF-DFT is and what we as a group are doing consider watching the following lecture video [part 1](https://www.youtube.com/watch?v=CoZUTMjU8C8) and [part 2](https://www.youtube.com/watch?v=iyx1C4vaP7k). \n", + "\n", + "### What this tutorial covers\n", + "This tutorial covers a broad range of topics from the loading and handling of the molecule dataset (which contains samples that combine an electron density with a target energy and target energy gradient) over visualizing molecules and electron densities all the way to understanding how our machine learning model is trained and can be used for density optimization. \n", + "\n", + "In summary, this tutorial will guide you through the following topics:\n", + "1) [**datamodule**](./tutorial_1_datamodule.ipynb): Understanding the following classes that handle the loading and processing of our data: OFDataset, OFData, OFBatch, OFLoader, OFDataModule, BasisInfo\n", + "\n", + "2) [**visualization**](./tutorial_2_visualization.ipynb): Demonstration of how to visualize molecules and electron densities in 3D via molview.org and via Pyvista\n", + "\n", + "3) [**transforms**](./tutorial_3_transforms.ipynb): Demystifying the MasterTransformation class that handles the transformation of data samples before they can be passed to the model for training (including a visualization of the basis transforms). Keywords: Global symmetric natrep, local frames, gradient projection.\n", + "\n", + "4) [**hydra & omegaconf**](./tutorial_4_hydra_omegaconf.ipynb): Understanding how configs are managed with hydra and omegaconf, including hydra overrides, omegaconf syntax and omegaconf resolver.\n", + "\n", + "5) [**mldftlitmodule**](./tutorial_5_mldftlitmodel.ipynb): Understaning how to train a model based on our PytorchLightning class MLDFTLitModule, including a closer look into the forward, training_step, backpropagate and the usage of dataset statistics in our model. \n", + "\n", + "6) [**density optimization**](./tutorial_6_density_optimization.ipynb): Demonstration of how to load a trained model from a checkpoint and perform density optimization, including important plots to evaluate the density optimization process. " + ] + }, + { + "cell_type": "markdown", + "id": "eb36cabb", + "metadata": {}, + "source": [ + "Before starting the tutorial, please take a look at the [README](../../README.md) to set up your virtual environment." + ] + }, + { + "cell_type": "markdown", + "id": "252daebd", + "metadata": {}, + "source": [ + "Please execute the following cell to download two small sample datasets from our huggingface model repository (https://huggingface.co/sciai-lab/structures25/tree/main). One dataset contains QM9 molecule and one the larger QMugs molecules. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "217bcba1", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# download a small dataset from huggingface that contains QM9 and QMugs data\n", + "# https://huggingface.co/docs/datasets/cache#cache-directory\n", + "# The default cache directory is `~/.cache/huggingface/datasets`\n", + "# You can change it by setting this variable to any path you like\n", + "CACHE_DIR = None # e.g. change it to \"./hf_cache\"\n", + "\n", + "\n", + "# https://huggingface.co/sciai-lab/structures25/tree/main\n", + "print(\"Downloading minimal dataset for QM9 and QMugs from huggingface...\")\n", + "os.environ[\n", + " \"HF_HUB_DISABLE_PROGRESS_BARS\"\n", + "] = \"1\" # to avoid problems with the progress bar in some environments\n", + "from huggingface_hub import snapshot_download\n", + "\n", + "data_path = snapshot_download(\n", + " repo_id=\"sciai-lab/minimal_data_QM9_QMugs\", cache_dir=CACHE_DIR, repo_type=\"dataset\"\n", + ")\n", + "\n", + "print(f\"Successfully downloaded data to the following path {data_path}.\")" + ] + }, + { + "cell_type": "markdown", + "id": "19cd7ec3", + "metadata": {}, + "source": [ + "If you want to train models properly beyond this tutorial (on the IWR servers), please set the environment variables `DFT_DATA` and `DFT_MODELS` in your `.bashrc` or `.zshrc` file:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c62ac86c", + "metadata": {}, + "outputs": [], + "source": [ + "# overall hint: to be able to go through this tutorial and execute all cells,\n", + "# you should be able to access our data and models folder\n", + "# and should have set the two environment variables DFT_DATA and DFT_MODELS:\n", + "\n", + "print(\"DFT_DATA:\", os.getenv(\"DFT_DATA\")) # set is to /export/scratch/ialgroup/dft_data\n", + "print(\n", + " \"DFT_MODELS:\", os.getenv(\"DFT_MODELS\")\n", + ") # set this one to where you want to save your trained models" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mldft", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorial/tutorial_1_datamodule.ipynb b/notebooks/tutorial/tutorial_1_datamodule.ipynb new file mode 100644 index 0000000..5505aef --- /dev/null +++ b/notebooks/tutorial/tutorial_1_datamodule.ipynb @@ -0,0 +1,430 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9f1be328", + "metadata": {}, + "source": [ + "# Tutorial 1: Understanding our OFDatamodule that handles all data samples during training\n", + "\n", + "The goal of this tutorial is to understand the behaviour and interplay of following classes that handle the loading and processing of our data: OFDataset, OFData, OFBatch, OFLoader, OFDataModule." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbe5acbd", + "metadata": {}, + "outputs": [], + "source": [ + "# import necessary packages\n", + "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import rich\n", + "import torch\n", + "from hydra import compose, initialize\n", + "from hydra.utils import instantiate\n", + "\n", + "# this makes sure that code changes are reflected without restarting the notebook\n", + "# this can be helpful if you want to play around with the code in the repo\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# omegaconf is used for configuration management\n", + "# omegaconf custom resolvers are small functions used in the config files like \"get_len\" to get lengths of lists\n", + "from mldft.utils import omegaconf_resolvers # this registers omegaconf custom resolvers\n", + "from mldft.utils.log_utils.config_in_tensorboard import dict_to_tree\n", + "\n", + "# download a small dataset from huggingface that contains QM9 and QMugs data (possibly already downloaded)\n", + "# and change the DFT_DATA environment variable to the directory where the data is stored\n", + "\n", + "# https://huggingface.co/docs/datasets/cache#cache-directory\n", + "# The default cache directory is `~/.cache/huggingface/datasets`\n", + "# You can change it by setting this variable to any path you like\n", + "CACHE_DIR = None # e.g. change it to \"./hf_cache\"\n", + "\n", + "# clone the full repo\n", + "# https://huggingface.co/sciai-lab/structures25/tree/main\n", + "os.environ[\n", + " \"HF_HUB_DISABLE_PROGRESS_BARS\"\n", + "] = \"1\" # to avoid problems with the progress bar in some environments\n", + "from huggingface_hub import snapshot_download\n", + "\n", + "data_path = snapshot_download(\n", + " repo_id=\"sciai-lab/minimal_data_QM9_QMugs\", cache_dir=CACHE_DIR, repo_type=\"dataset\"\n", + ")\n", + "\n", + "dft_data = os.environ.get(\"DFT_DATA\", None)\n", + "os.environ[\"DFT_DATA\"] = data_path\n", + "print(\n", + " f\"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4d41ea1b", + "metadata": {}, + "source": [ + "## 1 Loading the datamodule from our gigantic config" + ] + }, + { + "cell_type": "markdown", + "id": "30ac0051", + "metadata": {}, + "source": [ + "First, we load a large config as Omegaconf Dict config for training a model\n", + "with the defaut settings for data, optimizer, transforms, basis set, etc.\n", + "For now you can think of the config as a large nested dictionary that contains all settings\n", + "and hyperparameters used for training our OF-DFT model.\n", + "Later in the tutorial ([tutorial_4_hydra_omega_conf](./tutorial_4_hydra_omegaconf.ipynb)), we will go into more detail about how this works." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30a6c66c", + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf.dictconfig import DictConfig\n", + "\n", + "# the following initialize already handles the communication and combination\n", + "# of the different config files, e.g. for data and the model\n", + "with initialize(version_base=None, config_path=\"../../configs/ml\"):\n", + " config = compose(\n", + " config_name=\"train.yaml\",\n", + " overrides=[\n", + " # we need one simple override here but otherwise we just use the default setting (see tutorial 4 for more information)\n", + " \"data.dataset_name=QM9_perturbed_fock\", # this will no longer be necessary once the \"fixed\" is removed from the dataset_name\n", + " ],\n", + " )\n", + "\n", + "# remove the hydra specific stuff that only works in @hydra.main decorated functions\n", + "config.paths.output_dir = \"example_path\"\n", + "\n", + "# let us take a look at the part of the config data is used specifically for configuring the data module\n", + "rich.print(dict_to_tree(config.data.datamodule, guide_style=\"dim\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb4203c1", + "metadata": {}, + "outputs": [], + "source": [ + "from mldft.ml.data.datamodule import OFDataModule\n", + "\n", + "# we will now use this part of the config to instantiate important individual parts\n", + "# of the full training pipeline e.g. the datamodule\n", + "datamodule = instantiate(config.data.datamodule)\n", + "datamodule.batch_size = 4 # set batch size to 4 (relatively small) for demonstration purposes\n", + "print(\"Successfully instantiated datamodule:\", type(datamodule))\n", + "datamodule.setup(stage=\"fit\") # prepare the data, e.g. split into train, val, test\n", + "# with stage=\"fit\" no test set is prepared, only the train and validation set used during training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eddb61c3", + "metadata": {}, + "outputs": [], + "source": [ + "# to get a quick look of what is combined in the datamodule, we can look at its __dict__\n", + "datamodule.__dict__" + ] + }, + { + "cell_type": "markdown", + "id": "b15aa52a", + "metadata": {}, + "source": [ + "## 2 A first look at the dataset and a single sample" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c60ad939", + "metadata": {}, + "outputs": [], + "source": [ + "# In some place like this import the respective class so that you can click to definition\n", + "from mldft.ml.data.components.dataset import OFDataset\n", + "\n", + "# let's look at the dataset(s):\n", + "# print the length ot the train and validation set used during training:\n", + "# so-called \"split files\" are handling the split into disjoint train, val and test set\n", + "print(f\"Length of train set: {len(datamodule.train_set)}\")\n", + "print(f\"Length of val set: {len(datamodule.val_set)}\")\n", + "print(f\"type of train set: {type(datamodule.train_set)}\")\n", + "print(f\"isinstance of OFDataset {isinstance(datamodule.train_set, OFDataset)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1ab17d8", + "metadata": {}, + "outputs": [], + "source": [ + "from mldft.ml.data.components.of_data import OFData\n", + "\n", + "# get a single sample for the train set\n", + "sample = datamodule.train_set[0]\n", + "print(\"Atom positions:\", sample.pos)\n", + "print(\"Atom types:\", sample.atomic_numbers)\n", + "print(\"number of coefficients:\", sample.coeffs.shape)\n", + "print(\n", + " \"Integrals of basis functions used to describe the density:\", sample.dual_basis_integrals.shape\n", + ")\n", + "print(\"scf_iteration:\", sample.scf_iteration)\n", + "print(\"Energy label (kinetic energy + XC energy):\", sample.energy_label)\n", + "print(\"Energy key:\", datamodule.train_set.of_data_kwargs[\"energy_key\"])\n", + "print(\"Is sample an instance of OFData?\", isinstance(sample, OFData))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7e42ee8", + "metadata": {}, + "outputs": [], + "source": [ + "from mldft.utils.molecules import build_molecule_ofdata\n", + "\n", + "# need basis info to build a pySCF molecule object\n", + "# see below for more details on basis_info\n", + "basis_info = instantiate(config.data.basis_info)\n", + "\n", + "# build a pySCF molecule object from the OFData sample\n", + "mol = build_molecule_ofdata(sample, basis=basis_info.basis_dict)\n", + "print(f\"type : {type(mol)}, xyz string of that molecule:\\n\")\n", + "print(mol.tostring(\"xyz\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e213e5e9", + "metadata": {}, + "outputs": [], + "source": [ + "from pyscf.lib import param\n", + "\n", + "bohr2ang = param.BOHR # approx 0.529177 , i.e. 1 bohr = 0.529177 Angstrom\n", + "# our dataset works in the \"distance\" unit Bohr but others (like RDKit in this case) work in Angstrom\n", + "# to see how both are consistent we can convert the positions\n", + "print(\"Positions in Angstrom:\\n\", sample.pos * bohr2ang)\n", + "\n", + "# note that from the pyscf.Mole object we can also get the atom positions in different units via:\n", + "print(\"Positions in Angstrom from pyscf.Mole:\\n\", mol.atom_coords(unit=\"Angstrom\"))\n", + "print(\"Positions in Bohr from pyscf.Mole:\\n\", mol.atom_coords(unit=\"Bohr\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6261df2", + "metadata": {}, + "outputs": [], + "source": [ + "from rdkit.Chem import Draw\n", + "\n", + "from mldft.utils.conversions import pyscf_to_rdkit\n", + "\n", + "# please, note that the transformation from a set of atom positions (e.g. xyzfile) to an rdkit molecule\n", + "# with bonds (and nice pictures/structure as below) is not necessarily well defined,\n", + "# since it is non-trivial to infer chemical bonds from just positions and atom types\n", + "# (though this should not be an issue for classic QM9 and QMUGS molecules)\n", + "\n", + "rdkit_mol = pyscf_to_rdkit(mol)\n", + "print(\"type\", type(rdkit_mol))\n", + "# show the molecule with rdkit\n", + "img = Draw.MolToImage(rdkit_mol)\n", + "plt.imshow(img)\n", + "plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "f0e7cf40", + "metadata": {}, + "source": [ + "## 3 Our density representation and the BasisInfo class" + ] + }, + { + "cell_type": "markdown", + "id": "f123d26e", + "metadata": {}, + "source": [ + "We represent the electron density $\\rho(\\vec r)$, which is a function of 3D space, as a linear combination of so-called atom-cendered basis functions (each is a function of 3D space localized around a different atoms in the molecule).\n", + "$$\\rho(\\vec r) = \\sum_\\mu p_\\mu \\omega_\\mu(\\vec r)$$\n", + "$p_\\mu$ are the density coefficients and $\\omega_\\mu(\\vec r)$ are the different basis functions. We use Gaussian type orbitals (GTOs) as basis functions which combine a Gaussian-like radial part with a spherical harmonic angular part. Please take a look at the [STRUCTURES25 paper](https://pubs.acs.org/doi/10.1021/jacs.5c06219) for more details. " + ] + }, + { + "cell_type": "markdown", + "id": "de80f53b", + "metadata": {}, + "source": [ + "Above, we have seen a visulization of a electron density using the coefficients and the basis functions.\n", + "let us now look at the basis info object in more detail:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed221186", + "metadata": {}, + "outputs": [], + "source": [ + "from mldft.ml.data.components.basis_info import BasisInfo\n", + "\n", + "# the essential info about all basis functions for the different atom types is stored in\n", + "# basis_info.basis_dict, a dictionary with the following structure:\n", + "# key: atom type val: list of (angular momentum, [exponent, weighting coeffs for contractions])\n", + "# see https://pyscf.org/user/gto.html#basis-format for details\n", + "basis_info.basis_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d758121b", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"atomic numbers in the dataset:\", basis_info.atomic_numbers)\n", + "print(\"Number of basis functions/coeffs per atom type:\", basis_info.basis_dim_per_atom)\n", + "\n", + "# for instance, we can take a look at the integrals of the basis functions for Hydrogen:\n", + "# all basis functions that have l>0 integrate to zero:\n", + "basis_info.integrals[0]" + ] + }, + { + "cell_type": "markdown", + "id": "9089e111", + "metadata": {}, + "source": [ + "## 4 Our dataloader converts OFData into OFBatch objects\n", + "We are gradually moving towards training a model. For that, we take a look at the dataloaders that combine multiple molecules into batches, which are then passed to the model for training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5564b2cb", + "metadata": {}, + "outputs": [], + "source": [ + "from mldft.ml.data.components.loader import OFLoader\n", + "from mldft.ml.data.components.of_batch import OFBatch\n", + "\n", + "datamodule.batch_size = 4 # set batch size to 4 (relatively small) for demonstration purposes\n", + "train_loader = datamodule.train_dataloader()\n", + "for batch in train_loader:\n", + " # get the first batch in the train loader as the model would\n", + " batch\n", + " break\n", + "\n", + "# an alternative to get the first batch from the train_loader is the following:\n", + "# batch = next(iter(train_loader))\n", + "\n", + "# one special thing about geometric graph data:\n", + "# different molecules have different number of atoms, therefore combining them into\n", + "# one batch is not as simple as stacking them into a tensor\n", + "# but it is more an appending into one large graph with all atoms and the\n", + "# batch.batch tensor indicating which atom belongs to which molecule in the large graph\n", + "print(\"number of molecules in the batch:\", batch.num_graphs)\n", + "print(\"Number of atoms in the batch:\", batch.num_nodes)\n", + "print(\"batch.batch:\", batch.batch, \"len(batch.batch):\", len(batch.batch))\n", + "print(\"Length of 'concatenated' atom positions:\", batch.pos.shape)\n", + "print(\"Length of 'concatenated' atomic numbers:\", batch.atomic_numbers.shape)\n", + "\n", + "# find out how many atoms are in each molecule in the batch\n", + "num_atoms_per_mol = torch.bincount(batch.batch)\n", + "print(\"Number of atoms per molecule in the batch:\", num_atoms_per_mol)\n", + "# average, max, min, number of atoms in the molecules in the batch\n", + "print(\n", + " \"average number of atoms per molecule in the batch:\", num_atoms_per_mol.float().mean().item()\n", + ")\n", + "print(\"max number of atoms per molecule in the batch:\", num_atoms_per_mol.max().item())\n", + "print(\"min number of atoms per molecule in the batch:\", num_atoms_per_mol.min().item())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2b1587b", + "metadata": {}, + "outputs": [], + "source": [ + "# a batch can be separated into individual data samples (molecules) again via:\n", + "list_of_molecules = batch.to_data_list()\n", + "print(\n", + " \"Length of list_of_molecules:\",\n", + " len(list_of_molecules),\n", + " \"first mol in list is:\",\n", + " list_of_molecules[0],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "484e47e1", + "metadata": {}, + "source": [ + "It is also possible to create batches manually from a list of OFData samples:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c79531e4", + "metadata": {}, + "outputs": [], + "source": [ + "from mldft.ml.data.components.of_data import Representation\n", + "\n", + "# let us add a new property to each sample\n", + "# (this is a bit special since we always specify the representation of items, see sample.representations)\n", + "for molecule in list_of_molecules:\n", + " molecule.add_item(\n", + " key=\"example_property\", value=torch.tensor(42.0), representation=Representation.SCALAR\n", + " )\n", + "\n", + "batch = OFBatch.from_data_list(list_of_molecules)\n", + "batch.example_property" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mldft", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorial/tutorial_2_visualization.ipynb b/notebooks/tutorial/tutorial_2_visualization.ipynb new file mode 100644 index 0000000..728f0db --- /dev/null +++ b/notebooks/tutorial/tutorial_2_visualization.ipynb @@ -0,0 +1,311 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9f1be328", + "metadata": {}, + "source": [ + "# Tutorial 2: 3D Visualization of molecules, electron densities and basis functions\n", + "\n", + "The goal of this tutorial is to demonstrate how to visualize molecules and electron densities in 3D via molview.org and via Pyvista." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbe5acbd", + "metadata": {}, + "outputs": [], + "source": [ + "# import necessary packages\n", + "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from hydra import compose, initialize\n", + "from hydra.utils import instantiate\n", + "\n", + "# this makes sure that code changes are reflected without restarting the notebook\n", + "# this can be helpful if you want to play around with the code in the repo\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# omegaconf is used for configuration management\n", + "# omegaconf custom resolvers are small functions used in the config files like \"get_len\" to get lengths of lists\n", + "from mldft.utils import omegaconf_resolvers # this registers omegaconf custom resolvers\n", + "from mldft.utils.molecules import build_molecule_ofdata\n", + "\n", + "# download a small dataset from huggingface that contains QM9 and QMugs data (possibly already downloaded)\n", + "# and change the DFT_DATA environment variable to the directory where the data is stored\n", + "\n", + "# https://huggingface.co/docs/datasets/cache#cache-directory\n", + "# The default cache directory is `~/.cache/huggingface/datasets`\n", + "# You can change it by setting this variable to any path you like\n", + "CACHE_DIR = None # e.g. change it to \"./hf_cache\"\n", + "\n", + "# clone the full repo\n", + "# https://huggingface.co/sciai-lab/structures25/tree/main\n", + "os.environ[\n", + " \"HF_HUB_DISABLE_PROGRESS_BARS\"\n", + "] = \"1\" # to avoid problems with the progress bar in some environments\n", + "from huggingface_hub import snapshot_download\n", + "\n", + "data_path = snapshot_download(\n", + " repo_id=\"sciai-lab/minimal_data_QM9_QMugs\", cache_dir=CACHE_DIR, repo_type=\"dataset\"\n", + ")\n", + "\n", + "dft_data = os.environ.get(\"DFT_DATA\", None)\n", + "os.environ[\"DFT_DATA\"] = data_path\n", + "print(\n", + " f\"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}.\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30a6c66c", + "metadata": {}, + "outputs": [], + "source": [ + "# first we load our large config, instantiate the datamodule and obtain a single sample\n", + "with initialize(version_base=None, config_path=\"../../configs/ml\"):\n", + " config = compose(\n", + " config_name=\"train.yaml\",\n", + " overrides=[\n", + " # we need one simple override here but otherwise we just use the default setting (see tutorial 4 for more information)\n", + " \"data.dataset_name=QM9_perturbed_fock\", # this will no longer be necessary once the \"fixed\" is removed from the dataset_name\n", + " ],\n", + " )\n", + "\n", + "# remove the hydra specific stuff that only works in @hydra.main decorated functions\n", + "config.paths.output_dir = \"example_path\"\n", + "\n", + "datamodule = instantiate(config.data.datamodule)\n", + "datamodule.setup(stage=\"fit\")\n", + "sample = datamodule.train_set[0]\n", + "\n", + "# need basis info to build a pySCF molecule object\n", + "# see below for more details on basis_info\n", + "basis_info = instantiate(config.data.basis_info)\n", + "\n", + "# build a pySCF molecule object from the OFData sample\n", + "mol = build_molecule_ofdata(sample, basis=basis_info.basis_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "7bfa54f4", + "metadata": {}, + "source": [ + "## 3D visualization based on [molview.org](molview.org)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ddae44cf", + "metadata": {}, + "outputs": [], + "source": [ + "# A simple way to visualize molecules in 3D is via molview.org.\n", + "# Note though that the displayed geometry is inferred from the SMILES string, so it does not exactly correspond to the geometry in the sample object.\n", + "# if you click on the link you will see 3D structure of the molecule in the browser:\n", + "from mldft.utils.molecules import get_mol_view_link\n", + "\n", + "print(\"Click on the following link to visualize the molecule in 3D in your browser:\")\n", + "get_mol_view_link(mol)" + ] + }, + { + "cell_type": "markdown", + "id": "456cbf8b", + "metadata": {}, + "source": [ + "## 3D visualization in the notebook based on pyvista " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff2eacf3", + "metadata": {}, + "outputs": [], + "source": [ + "# --- before importing pyvista (or anything that pulls in trame) ---\n", + "import sys\n", + "\n", + "_orig_argv = sys.argv[:] # keep a copy\n", + "print(\"Original sys.argv:\", sys.argv)\n", + "sys.argv = [sys.argv[0]] + [a for a in sys.argv[1:] if not a.startswith(\"--f=\")]\n", + "\n", + "import pyvista # <-- trame/pyvista import happens here\n", + "\n", + "from mldft.utils.visualize_3d import (\n", + " get_local_frames_mesh_dict,\n", + " get_sticks_mesh_dict,\n", + " visualize_orbital,\n", + ")\n", + "\n", + "# this give a ball and stick model of the molecule\n", + "molecule_mesh = get_sticks_mesh_dict(mol)\n", + "\n", + "# this can be used to visualize local frames (in this case just the global coordinate frame at the origin)\n", + "global_frame_mesh = get_local_frames_mesh_dict(\n", + " origins=torch.zeros(1, 3),\n", + " bases=torch.eye(3)[None],\n", + " scale=2,\n", + ")\n", + "\n", + "# plot the molecule and the global frame using pyvista:\n", + "pyvista.set_jupyter_backend(\"html\")\n", + "pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)\n", + "pl.add_mesh(**global_frame_mesh)\n", + "pl.add_mesh(**molecule_mesh)\n", + "pl.reset_camera(\n", + " bounds=0.9 * np.stack([mol.atom_coords().min(0), mol.atom_coords().max(0)], axis=1).flatten()\n", + ")\n", + "\n", + "print(\n", + " \"3d visualization of our sample molecule together with the global coordinate frame placed at the origin:\"\n", + ")\n", + "img = pl.show(screenshot=True, window_size=(800, 400))\n", + "\n", + "# the following can also be used to for programmatric plotting of 3d molecules in matplotlib:\n", + "print(\"\\n\\nWe can also obtain a non-interactive image of the molecule:\")\n", + "plt.imshow(img)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "0a9e26f4", + "metadata": {}, + "source": [ + "For instance, we can use the screenshot function to create a small gif of the camera rotating around the molecule:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "054752c8", + "metadata": {}, + "outputs": [], + "source": [ + "# create a series of images in which the camera rotates around the molecule\n", + "\n", + "\n", + "def rotate_around_molecule(pl, n_frames=36, radius=1.5):\n", + " angles = np.linspace(0, 2 * np.pi, n_frames, endpoint=False)\n", + " images = []\n", + " for angle in angles:\n", + " camera_position = [\n", + " radius * np.cos(angle),\n", + " radius * np.sin(angle),\n", + " 0.5 * radius,\n", + " ]\n", + " pl.camera_position = (camera_position, (0, 0, 0), (0, 0, 1))\n", + " pl.render() # <- force update\n", + " img = pl.screenshot(transparent_background=False, window_size=(800, 400))\n", + " images.append(img)\n", + " return images\n", + "\n", + "\n", + "# create the images\n", + "images = rotate_around_molecule(pl, n_frames=60, radius=26.0)\n", + "\n", + "# create a gif\n", + "import imageio\n", + "\n", + "imageio.mimsave(\"molecule_rotation.gif\", images, fps=30, loop=0)\n", + "\n", + "# display the gif\n", + "from IPython.display import Image\n", + "\n", + "Image(filename=\"molecule_rotation.gif\") # this will display the gif in the notebook" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e72a9ea9", + "metadata": {}, + "outputs": [], + "source": [ + "# let us visualize the electron density:\n", + "pyvista.set_jupyter_backend(\"html\")\n", + "pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)\n", + "pl = visualize_orbital(\n", + " mol=mol,\n", + " coeff=sample.coeffs.numpy(),\n", + " plotter=pl,\n", + ")\n", + "pl.reset_camera(\n", + " bounds=0.9 * np.stack([mol.atom_coords().min(0), mol.atom_coords().max(0)], axis=1).flatten()\n", + ")\n", + "\n", + "print(\"Hint: you might have to zoom in or out a bit to see some thing at first.\")\n", + "print(\n", + " \"3d visualization of the electron density as linear combination of basis functions\\nusing the coefficients in the sample:\"\n", + ")\n", + "img = pl.show(screenshot=True, window_size=(800, 400))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f69aa5d5", + "metadata": {}, + "outputs": [], + "source": [ + "# let us visualize a single basis function:\n", + "coeffs = np.zeros(sample.coeffs.shape)\n", + "coeffs[194] = 1.0 # set one coefficient to one, all others to zero\n", + "\n", + "pyvista.set_jupyter_backend(\"html\")\n", + "pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)\n", + "pl = visualize_orbital(\n", + " mol=mol,\n", + " coeff=coeffs,\n", + " plotter=pl,\n", + " mode=\"isosurface\",\n", + " resolution=0.15,\n", + " isosurface_quantile=0.95,\n", + ")\n", + "\n", + "print(\"3d visualization of a single basis function (one coefficient set to 1, all others to 0):\")\n", + "img = pl.show(screenshot=True, window_size=(800, 400))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b1e72be", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mldft", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorial/tutorial_3_transforms.ipynb b/notebooks/tutorial/tutorial_3_transforms.ipynb new file mode 100644 index 0000000..da77210 --- /dev/null +++ b/notebooks/tutorial/tutorial_3_transforms.ipynb @@ -0,0 +1,707 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9f1be328", + "metadata": {}, + "source": [ + "# Tutorial 3: Demystifying our data transformations -- Mastering the MasterTransformation\n", + "\n", + "The goal of this tutorial is to understand the MasterTransformation class that handles the transformation of data samples before they can be passed to the model for training. The tutorial also includes a visualization of the most important transforms the so-called basis transforms." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbe5acbd", + "metadata": {}, + "outputs": [], + "source": [ + "# import necessary packages\n", + "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import rich\n", + "import torch\n", + "from hydra import compose, initialize\n", + "from hydra.utils import instantiate\n", + "\n", + "# this makes sure that code changes are reflected without restarting the notebook\n", + "# this can be helpful if you want to play around with the code in the repo\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# omegaconf is used for configuration management\n", + "# omegaconf custom resolvers are small functions used in the config files like \"get_len\" to get lengths of lists\n", + "from mldft.utils import omegaconf_resolvers # this registers omegaconf custom resolvers\n", + "from mldft.utils.log_utils.config_in_tensorboard import dict_to_tree\n", + "from mldft.utils.molecules import build_molecule_ofdata\n", + "\n", + "# download a small dataset from huggingface that contains QM9 and QMugs data (possibly already downloaded)\n", + "# and change the DFT_DATA environment variable to the directory where the data is stored\n", + "\n", + "# https://huggingface.co/docs/datasets/cache#cache-directory\n", + "# The default cache directory is `~/.cache/huggingface/datasets`\n", + "# You can change it by setting this variable to any path you like\n", + "CACHE_DIR = None # e.g. change it to \"./hf_cache\"\n", + "\n", + "# clone the full repo\n", + "# https://huggingface.co/sciai-lab/structures25/tree/main\n", + "os.environ[\n", + " \"HF_HUB_DISABLE_PROGRESS_BARS\"\n", + "] = \"1\" # to avoid problems with the progress bar in some environments\n", + "from huggingface_hub import snapshot_download\n", + "\n", + "data_path = snapshot_download(\n", + " repo_id=\"sciai-lab/minimal_data_QM9_QMugs\", cache_dir=CACHE_DIR, repo_type=\"dataset\"\n", + ")\n", + "\n", + "dft_data = os.environ.get(\"DFT_DATA\", None)\n", + "os.environ[\"DFT_DATA\"] = data_path\n", + "print(\n", + " f\"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}.\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30a6c66c", + "metadata": {}, + "outputs": [], + "source": [ + "# first, we load our large config, instantiate the datamodule and obtain a single sample\n", + "with initialize(version_base=None, config_path=\"../../configs/ml\"):\n", + " config = compose(\n", + " config_name=\"train.yaml\",\n", + " overrides=[\n", + " # we need one simple override here but otherwise we just use the default setting (see tutorial 4 for more information)\n", + " \"data.dataset_name=QM9_perturbed_fock\", # this will no longer be necessary once the \"fixed\" is removed from the dataset_name\n", + " ],\n", + " )\n", + "\n", + "# remove the hydra specific stuff that only works in @hydra.main decorated functions\n", + "config.paths.output_dir = \"example_path\"\n", + "\n", + "datamodule = instantiate(config.data.datamodule)\n", + "datamodule.setup(stage=\"fit\")\n", + "sample = datamodule.train_set[0]\n", + "\n", + "# need basis info to build a pySCF molecule object\n", + "# see below for more details on basis_info\n", + "basis_info = instantiate(config.data.basis_info)\n", + "\n", + "# build a pySCF molecule object from the OFData sample\n", + "mol = build_molecule_ofdata(sample, basis=basis_info.basis_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b324b5cf", + "metadata": {}, + "outputs": [], + "source": [ + "# one important but (slightly) tricky topic are the transforms that are applied\n", + "# to a data sample when loaded from the dataset\n", + "rich.print(dict_to_tree(config.data.transforms, guide_style=\"dim\"))" + ] + }, + { + "cell_type": "markdown", + "id": "7b97efb0", + "metadata": {}, + "source": [ + "## 1 The MasterTransformation class" + ] + }, + { + "cell_type": "markdown", + "id": "dbf612d4", + "metadata": {}, + "source": [ + "Some of our data transformations are quite expensive and should therefore not be performed on the fly during training.\n", + "As a solution, we have precomputed several different transformed versions of our datasets and saved them to the file servers (we call this cached data).\n", + "In this tutorial, two transforms have been applied previously to the data and are loaded as \"cached\":\n", + "* the transformation into local frames (local reference frames at every atom)\n", + "* and the global symmetric natural reparametrization (natrep),\n", + "that is, an orthonormalization of the basis functions" + ] + }, + { + "cell_type": "markdown", + "id": "2d9259ba", + "metadata": {}, + "source": [ + "### Cached data, basis transforms, pre- and post transforms" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac38f8e3", + "metadata": {}, + "outputs": [], + "source": [ + "# all transforms are combined into one class called the MasterTransformation\n", + "from mldft.ml.data.components.basis_transforms import MasterTransformation\n", + "\n", + "datamodule.transforms.__dict__\n", + "print(\"Name of (cached) transforms:\", datamodule.transforms.name)\n", + "print(\"Whether to use cached data:\", datamodule.transforms.use_cached_data, \"\\n\")\n", + "print(\"cached_basis_transforms:\", datamodule.transforms.cached_basis_transforms, \"\\n\")\n", + "# these transforms must therefore not be applied on the fly during training if cached data is used\n", + "# however if use_cached_data=FALSE, these transforms are actually still applied on the fly" + ] + }, + { + "cell_type": "markdown", + "id": "02c03922", + "metadata": {}, + "source": [ + "Global symmetric natrep is the reason that the basis function integrals are no longer zero for a majority of the basis functions (l>0 prior to natrep, cf. [Tutorial 1](./tutorial_1_datamodule.ipynb)).\n", + "In fact, natrep performs a change of basis to a new set of basis functions that are **orthonormal** on a global level." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b0e5843", + "metadata": {}, + "outputs": [], + "source": [ + "# basis function integrals are no longer zero:\n", + "# side note: just ignore the word \"dual\" in the \"dual_basis_integrals\" attribute name\n", + "print(\"Basis function integrals after natrep (first 10):\", sample.dual_basis_integrals[:10])" + ] + }, + { + "cell_type": "markdown", + "id": "ed472e9b", + "metadata": {}, + "source": [ + "#### Global symmetric natrep\n", + "By diagonalizing the overlap matrix $O_{\\mu\\nu}$ of the basis functions $\\omega_\\mu$,\n", + "$$\n", + "O_{\\mu\\nu} = \\int \\mathrm d^3r \\ \\omega_\\mu \\omega_\\nu \n", + "$$,\n", + "we find a change of basis that can be used to make all basis functions mutually orthogonal. Furthermore, we can normalize the resulting basis functions so that the overlap matrix in the transformed basis becomes the identity matrix." + ] + }, + { + "cell_type": "markdown", + "id": "381e04cb", + "metadata": {}, + "source": [ + "All other transforms are applied on the fly during training\n", + "but there are several different types of such transforms:\n", + "\n", + "First, the **pre_transforms** are directly applied to the OFData sample when it is loaded from the disk.\n", + "All molecules (transformed or not) are saved as individual zarr files on the disk.\n", + "\n", + "In the default case the pre_transforms are\n", + "* ToTorch: to convert numpy arrays to torch tensors\n", + "* ProjectGradient: to project the gradient label orthogonally to the direction in which the number of electrons changes\n", + "(this is important since we want to keep the number of electrons constant during density optimization)\n", + "* AddFullEdgeIndex: to add a list of edges of a fully connected graph to the sample data\n", + "(used for message passing neural networks)" + ] + }, + { + "cell_type": "markdown", + "id": "4566189e", + "metadata": {}, + "source": [ + "#### Gradient projection\n", + "The gradient projection ensures that a step in the direction of the gradient will not change the number of electrons.\n", + "\n", + "Let $\\mathrm w_\\mu$ be the integral of the basis functions $\\omega(\\vec r)$:\n", + "$$\n", + "\\mathrm w_\\mu = \\int \\mathrm d^3 \\ \\vec r \\omega_\\mu(\\vec r)\n", + "$$ \n", + "The number of electrons for a given density is then:\n", + "$$\n", + "N_e = \\int \\mathrm d^3 \\ \\vec r \\sum_\\mu p_\\mu \\omega_\\mu(\\vec r) = \\sum_\\mu p_\\mu \\mathrm w_\\mu = \\mathbf p^T \\mathbf w\n", + "$$.\n", + "If we collect all $\\mathrm w_\\mu$ in a vector $\\mathbf w$, then the projection operator that acts on the gradients is given by\n", + "$$\n", + "\\Pi = I - \\frac{\\mathbf w \\mathbf w^T}{\\mathbf w^T\\mathbf w}\n", + "$$. One can easily check that indeed $\\Pi \\Pi = \\Pi$. If we now consider an arbitrary change to our density $p \\to p' = p + \\Delta p$ the number of electrons of the density will change. But for $p \\to p' = p + \\Pi \\Delta p$ is stays constant:\n", + "$$\n", + "N_e' = (\\mathbf p + \\Pi \\Delta \\mathbf p)^T \\mathbf w = (\\mathbf p^T + \\Delta \\mathbf p^T \\Pi^T) \\mathbf w = N_e + \\Delta \\mathbf p^T \\Big(I - \\frac{\\mathbf w \\mathbf w^T}{\\mathbf w^T\\mathbf w} \\Big) \\mathbf w = N_e + \\Delta \\mathbf p^T \\Big(\\mathbf w - \\mathbf w \\frac{\\mathbf w^T\\mathbf w}{\\mathbf w^T\\mathbf w}\\Big) = N_e\n", + "$$.\n", + "Indeed, $\\Pi$ is a projection operator that when applied to the gradient step ($\\Delta \\mathbf p = \\text{learning rate} \\times \\nabla_p E$) preserves the number of electrons $N_e$ of the corresponding electron density." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbd598b3", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"pre_transforms:\", datamodule.transforms.pre_transforms)" + ] + }, + { + "cell_type": "markdown", + "id": "23cd3b4d", + "metadata": {}, + "source": [ + "Second, **additional_pre_transforms**: In contrast to pre_transforms,\n", + "additional_pre_transforms are only used if NOT cached data is used\n", + "therefore, in our case even though it is specified in the config the\n", + "AddOverlapMatrix transform is not applied.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffdc8e90", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"additional_pre_transforms:\", datamodule.transforms.additional_pre_transforms)" + ] + }, + { + "cell_type": "markdown", + "id": "d9417ce4", + "metadata": {}, + "source": [ + "Third, the **basis_transforms** are always applied." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ace12c0", + "metadata": {}, + "outputs": [], + "source": [ + "# since we use cached data, we do not apply any basis transforms on the fly\n", + "print(\"basis_transforms:\", datamodule.transforms.basis_transforms, \"\\n\")" + ] + }, + { + "cell_type": "markdown", + "id": "c85cb7ae", + "metadata": {}, + "source": [ + "Fourth, the **post_transforms**: These are also always applied\n", + "and typically prepare the data for the model,\n", + "e.g. one can change the dtype here between float32 and float64.\n", + "In the default case, we use ToTorch to make sure that all attributes in OFData are converted to torch tensors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a85639e3", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"post_transforms:\", datamodule.transforms.post_transforms)" + ] + }, + { + "cell_type": "markdown", + "id": "bdbf0e8f", + "metadata": {}, + "source": [ + "**The reason for our complicated transform structure is the following:** \n", + "Basis transforms, such as the local frames transforms or the natrep transformation, affect the basis functions (see below). Therefore, for consistency, basis transforms transform *all* fields in the sample according to their geometric representation. Thus, the pre_transforms are important to potentially add attributes to the data samples which should then be affected by the basis transforms, e.g. the AddOverlapMatrix transform." + ] + }, + { + "cell_type": "markdown", + "id": "b779ee4e", + "metadata": {}, + "source": [ + "Next, let us manually change the `use_cached_data` to False such that\n", + "AddOverlapMatrix transform will actually be applied.\n", + "In that case, the data will be loaded as untransformed data\n", + "and then the LocalFrames and SymmetricGlobalNatrep basis transforms will be applied on the fly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "604c7f83", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "# let us initialize a second datamodule but with use_cached_data=False\n", + "with initialize(version_base=None, config_path=\"../../configs/ml\"):\n", + " config_no_cache = compose(\n", + " config_name=\"train.yaml\",\n", + " overrides=[\n", + " # we need one simple override here but otherwise we just use the default setting (see tutorial 4 for more information)\n", + " \"data.dataset_name=QM9_perturbed_fock\", # this will no longer be necessary once the \"fixed\" is removed from the dataset_name\n", + " \"data.transforms.use_cached_data=False\", # override to not use cached data\n", + " ],\n", + " )\n", + "\n", + "# remove the hydra specific stuff that only works in @hydra.main decorated functions\n", + "config_no_cache.paths.output_dir = \"example_path\"\n", + "\n", + "datamodule_no_cache = instantiate(config_no_cache.data.datamodule)\n", + "datamodule_no_cache.setup(stage=\"fit\")\n", + "\n", + "t0 = time.time()\n", + "sample_cache = datamodule.train_set[0]\n", + "print(\n", + " f\"Sample has overlap_matrix: {hasattr(sample_cache, 'overlap_matrix')}, since pre_transforms are not used.\"\n", + ")\n", + "print(f\"Loading sample with cached transforms took: {time.time()-t0:.2f} seconds\")\n", + "t0 = time.time()\n", + "sample_no_cache = datamodule_no_cache.train_set[0]\n", + "print(\n", + " f\"Sample has overlap_matrix: {hasattr(sample_no_cache, 'overlap_matrix')}, since pre_transforms are used.\"\n", + ")\n", + "print(f\"Loading sample without cached transforms took: {time.time()-t0:.2f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd5072ce", + "metadata": {}, + "outputs": [], + "source": [ + "# We can confirm that after symmetric global natrep,\n", + "# the overlap matrix of the basis functions is indeed close to the identity matrix:\n", + "print(\n", + " \"Overlap matrix close to identity\",\n", + " torch.allclose(sample_no_cache.overlap_matrix, torch.eye(sample_no_cache.coeffs.shape[0])),\n", + ")\n", + "\n", + "# let's confirm that otherwise the two samples are identical\n", + "for key in [\"pos\", \"coeffs\", \"ground_state_coeffs\", \"dual_basis_integrals\"]:\n", + " print(\n", + " f\"Checking that {key} are close:\", torch.allclose(sample_cache[key], sample_no_cache[key])\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "f771e790", + "metadata": {}, + "source": [ + "## 2 Visualization of local frames transformation\n", + "In our project (when using the Graphformer architecture as in STRUCTURES25), we use local frames to canonicalize the geometric input data to achieve rotational equivariance.\n", + "\n", + "Therefore, let us visualize the local frames (computed base on nearest neighbor positions):\n", + "Note that the x-axis (the green arrow) of the local frames is not visible as it always points towards the nearest neighbor atom and is therefore \"swallowed\" by the bond between these two atoms." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "647e5340", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "# keep only the program name so downstream parsers don't see Jupyter's -f=...\n", + "sys.argv = sys.argv[:1]\n", + "\n", + "import pyvista\n", + "\n", + "# let use explictily calculate local frames for the given sample:\n", + "from mldft.ml.models.components.local_frames_module import (\n", + " LocalFramesTransformMatrixDense,\n", + ")\n", + "from mldft.utils.visualize_3d import (\n", + " get_local_frames_mesh_dict,\n", + " get_sticks_mesh_dict,\n", + " visualize_orbital,\n", + ")\n", + "\n", + "# predict the local frames from the atomic positions and atom types:\n", + "local_frames_module = LocalFramesTransformMatrixDense()\n", + "transformation_matrix, lframes = local_frames_module.sample_forward(sample, return_lframes=True)\n", + "\n", + "local_frames_mesh = get_local_frames_mesh_dict(\n", + " origins=sample.pos,\n", + " bases=lframes,\n", + " scale=2,\n", + " # axes_radius_scale=0.06\n", + ")\n", + "\n", + "# this gives a ball and stick model of the molecule\n", + "molecule_mesh = get_sticks_mesh_dict(mol)\n", + "molecule_mesh[\"opacity\"] = 1\n", + "\n", + "# plot the molecule and the global frame using pyvista:\n", + "pyvista.set_jupyter_backend(\"html\")\n", + "pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)\n", + "pl.add_mesh(**local_frames_mesh)\n", + "pl.add_mesh(**molecule_mesh)\n", + "pl.enable_shadows()\n", + "pl.reset_camera(\n", + " bounds=0.9 * np.stack([mol.atom_coords().min(0), mol.atom_coords().max(0)], axis=1).flatten()\n", + ")\n", + "\n", + "print(\"3d visualization of local coordinate frames at every atom:\")\n", + "img = pl.show(screenshot=True, window_size=(800, 400))" + ] + }, + { + "cell_type": "markdown", + "id": "55708b8e", + "metadata": {}, + "source": [ + "Let us illustrate the effect of the local frames transformation at a single basis function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a36e39c", + "metadata": {}, + "outputs": [], + "source": [ + "# let us visualize a single basis function:\n", + "basis_function_idx = 370\n", + "node_idx = sample.coeff_ind_to_node_ind[basis_function_idx]\n", + "coeffs = np.zeros(sample.coeffs.shape)\n", + "coeffs[basis_function_idx] = 1.0 # set one coefficient to one, all others to zero\n", + "\n", + "# this can be used to visualize local frames\n", + "# (in this case the global coordinate frame at the position at the atom)\n", + "global_frame_mesh = get_local_frames_mesh_dict(\n", + " origins=sample.pos[node_idx].view(1, 3),\n", + " # origins=torch.zeros(1, 3),\n", + " bases=torch.eye(3)[None],\n", + " scale=2,\n", + ")\n", + "\n", + "pyvista.set_jupyter_backend(\"html\")\n", + "pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)\n", + "pl = visualize_orbital(\n", + " mol=mol,\n", + " coeff=coeffs,\n", + " plotter=pl,\n", + " mode=\"isosurface\",\n", + " resolution=0.15,\n", + " isosurface_quantile=0.6,\n", + ")\n", + "\n", + "pl.add_mesh(**global_frame_mesh)\n", + "print(\"P-Orbital like basis function without transforms (and global frame):\")\n", + "img = pl.show(screenshot=True, window_size=(800, 400))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a02adc07", + "metadata": {}, + "source": [ + "#### How to visualize transformed basis functions (some cool theory)\n", + "\n", + "The basis functions are taken from the BasisInfo object (see above where it is initialized, and [Tutorial1](./tutorial_1_datamodule.ipynb) for more details). Short recap: The basis functions $\\omega_\\mu(\\vec r)$ are used to represent the electron density $\\rho(\\vec r)$ via linear combination:\n", + "$$\n", + "\\rho(\\vec r) = \\sum_\\mu p_\\mu \\omega_\\mu(\\vec r) = \\mathbf p^T \\boldsymbol \\omega(\\vec r),\n", + "$$\n", + "where we have grouped the coefficients $p_\\mu$ and basis functions $\\omega_\\mu$ in $d$-dimensional vectors, i.e. $\\mathbf p, \\boldsymbol \\omega(\\vec r) \\in \\mathbb R^d$.\n", + "\n", + "A basis transformation changes the basis function $\\omega_\\mu$ into new $\\omega'_\\mu$ that are a linear combination of the $\\omega_\\mu$. Similarly, the coefficients are transformed into $p'_\\mu$ that are a linear combination of the $p_\\mu$. Let $A \\in \\mathrm{GL}(d)$ be a real $d \\times d$ basis transformation matrix. Under this transformation, we *demand* that the coefficients transform like vectors, i.e. $p'_\\mu = \\sum_\\nu A_{\\mu \\nu} p_\\nu$ or in short $\\mathbf p' = A \\mathbf p$. \n", + "\n", + "Now, we ask the following question: How must the transformed $\\omega'_\\mu$ look like so that the density function stays the same, i.e. $\\rho'(\\vec r) = \\mathbf p'^T \\boldsymbol \\omega'(\\vec r) = \\mathbf p^T \\boldsymbol \\omega(\\vec r) = \\rho$. The anser is the following\n", + "$$\n", + "\\boldsymbol \\omega' = \\big(A^{-1}\\big)^T \\boldsymbol \\omega \\ \\text{ , since then: } \\ \\rho' = \\mathbf p'^T \\boldsymbol \\omega' = (A \\mathbf p)^T \\big(A^{-1}\\big)^T \\boldsymbol \\omega = \\mathbf p^T \\big(A^{-1} A\\big)^T \\boldsymbol \\omega = \\rho .\n", + "$$\n", + "\n", + "In components, this transformation behavior reads $\\omega'_\\mu = \\big[ \\big(A^{-1}\\big)^T\\big]_{\\mu \\nu} \\omega_\\nu = \\omega_\\nu \\big(A^{-1}\\big)_{\\nu \\mu}$\n", + "Thus, when interpreting $\\boldsymbol \\omega$ as row vector, it transforms like $\\boldsymbol \\omega'^T = \\boldsymbol \\omega^T A^{-1}$, that is, $\\boldsymbol \\omega$ transforms as dual vector (with the inverse of $A$ from the right), as can be seen in the `transform_tensor` function in [basis_transforms.py](../../mldft/ml/data/components/basis_transforms.py). \n", + "\n", + "\n", + "Lastly, we want to answer the following question: How can we look at the transformed basis functions $\\boldsymbol \\omega'$ without actually chaning the basis function but by changing the coefficients? For that, we consider the special \"density\" defined by\n", + "$$\n", + "\\omega'_\\sigma(\\vec r) = \\big (p^{(\\sigma)} \\big )^T \\boldsymbol \\omega'(\\vec r) , \\ \\text{ with } \\ p^{(\\sigma)}_\\mu = \\begin{cases} 1 \\ \\text{ if } \\ \\mu = \\sigma \\\\\n", + "0 \\ \\text{ else } \\end{cases} .\n", + "$$\n", + "Now, based on the above considerations, we know how find the appropriate coefficients to visualize this density in the original untransformed basis, namely:\n", + "$$\n", + "\\omega'_\\sigma(\\vec r) = \\big ( \\mathbf p^{(\\sigma)} \\big )^T \\boldsymbol \\omega'(\\vec r) = \\big (\\mathbf p^{(\\sigma)} \\big )^T \\Big( \\big(A^{-1}\\big)^T \\boldsymbol \\omega \\Big) = \\big (A^{-1} \\mathbf p^{(\\sigma)} \\big )^T \\boldsymbol \\omega\n", + "$$\n", + "So, we conclude that we can effectively visualize the transformed basis function $\\omega'_\\sigma(\\vec r)$ in untransformed basis by using the following coefficients:\n", + "$$\n", + "\\mathbf p'^{(\\sigma)} = A^{-1} \\mathbf p^{(\\sigma)} .\n", + "$$ \n", + "This is exactly what we will do below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3298f06b", + "metadata": {}, + "outputs": [], + "source": [ + "# now we apply a transform to the coeffs to see how the basis function changes:\n", + "from mldft.ml.data.components.basis_transforms import transform_tensor\n", + "from mldft.ml.data.components.of_data import Representation\n", + "\n", + "# actually we transform the coeffs with the inverse to see how the basis function will change:\n", + "# (see explanation above)\n", + "transformed_coeffs = transform_tensor(\n", + " tensor=torch.from_numpy(coeffs).float(),\n", + " transformation_matrix=transformation_matrix.T, # the transpose is the inverse for Wigner-D matrices\n", + " inv_transformation_matrix=transformation_matrix, # the inverse of the inverse is the original matrix\n", + " representation=Representation.VECTOR, # ensures multiplication with A^{-1} from the left (see above)\n", + ")\n", + "\n", + "# this can be used to visualize local frames\n", + "global_frame_mesh = get_local_frames_mesh_dict(\n", + " origins=sample.pos[node_idx].view(1, 3), # at the position of the atom\n", + " bases=lframes[node_idx].view(1, 3, 3), # use local frame instead of global frame now\n", + " scale=2.5,\n", + " axes_radius_scale=0.06,\n", + ")\n", + "\n", + "pyvista.set_jupyter_backend(\"html\")\n", + "pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)\n", + "pl = visualize_orbital(\n", + " mol=mol,\n", + " coeff=transformed_coeffs,\n", + " plotter=pl,\n", + " mode=\"isosurface\",\n", + " resolution=0.15,\n", + " isosurface_quantile=0.6,\n", + " bond_radius=0.1,\n", + ")\n", + "\n", + "pl.add_mesh(**global_frame_mesh)\n", + "print(\"P-Orbital like basis function after local frames transformation (and local frame):\")\n", + "img = pl.show(screenshot=True, window_size=(800, 400))" + ] + }, + { + "cell_type": "markdown", + "id": "632e994a", + "metadata": {}, + "source": [ + "For more information, on irreducible representations, equivariance with respect to rotations and the Wigner-D matrices consider watching the following lecture video [part 1](https://www.youtube.com/watch?v=gbEaHqJA9vI) and [part 2](https://www.youtube.com/watch?v=1-Z50VmIf9s)." + ] + }, + { + "cell_type": "markdown", + "id": "97162a3e", + "metadata": {}, + "source": [ + "Indeed, we can see that the basis function is transformed. First the blue part of the handles points in the direction of the red axis of the **global** reference frame (first visulalization), and after the transform the blue part of the handles points in the direction of the red axis of the **local** reference frame." + ] + }, + { + "cell_type": "markdown", + "id": "5dbc4b06", + "metadata": {}, + "source": [ + "## 3 Visualization of NatRep transformation\n" + ] + }, + { + "cell_type": "markdown", + "id": "aa44136d", + "metadata": {}, + "source": [ + "As a next step, let us visualize the effect of global symmetric natrep\n", + "together with the local frames transform on the same basis function that we visualized above.\n", + "\n", + "\n", + "The effect of global symmetric natrep is that the basis function is now a linear combination of all basis functions where the coefficients of that linear combination are obtained from the basis change that diagonalizes the overlap matrix. \n", + "\n", + "\n", + "Therefore the basis function will be more delocalized but in particular the *symmetric* version of global natrep ensures that the overlap of the old and the new basis functions is maximized under the following metric:\n", + "$$\n", + "\\left \\| u - v \\right \\|^2 = \\int \\vert u(\\vec r) - v(\\vec r) \\vert^2 \\mathrm \\ \\mathrm d^3\\vec r \n", + "$$\n", + "Thus, the old and the new basis function are still fairly similar.\n", + "In doing so, the *symmetric* natrep ensures that the new basis functions are still fairly localized." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30089d06", + "metadata": {}, + "outputs": [], + "source": [ + "# the following will ensure that the basis transforms transformation matrix is added to the sample:\n", + "datamodule_no_cache.transforms.add_transformation_matrix = True\n", + "sample_with_trafo = datamodule_no_cache.train_set[0]\n", + "\n", + "# now we apply a transform to the coeffs to see how the basis function changes:\n", + "\n", + "# actually we transform the coeffs with the inverse to see how the basis function will change:\n", + "# (see explanation above)\n", + "transformed_coeffs2 = transform_tensor(\n", + " tensor=torch.from_numpy(coeffs).float(),\n", + " transformation_matrix=sample_with_trafo.inv_transformation_matrix, # use the inverse transformation matrix\n", + " inv_transformation_matrix=sample_with_trafo.transformation_matrix, # the inverse of the inverse is the original matrix\n", + " representation=Representation.VECTOR, # ensures multiplication with A^{-1} from the left (see above)\n", + ")\n", + "\n", + "# this can be used to visualize local frames\n", + "# (in this case just the global coordinate frame at the origin)\n", + "global_frame_mesh = get_local_frames_mesh_dict(\n", + " origins=sample.pos[node_idx].view(1, 3),\n", + " bases=lframes[node_idx].view(1, 3, 3), # use local frame instead of global frame now\n", + " scale=2.5,\n", + " axes_radius_scale=0.06,\n", + ")\n", + "\n", + "pyvista.set_jupyter_backend(\"html\")\n", + "isosurface_quantile = 0.9\n", + "pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)\n", + "pl = visualize_orbital(\n", + " mol=mol,\n", + " coeff=transformed_coeffs2,\n", + " plotter=pl,\n", + " mode=\"isosurface\",\n", + " resolution=0.15,\n", + " isosurface_quantile=isosurface_quantile,\n", + " bond_radius=0.1,\n", + ")\n", + "\n", + "pl.add_mesh(**global_frame_mesh)\n", + "\n", + "print(\n", + " \"P-Orbital like basis function after local frames transformation and global symmetric natrep (and local frame):\"\n", + ")\n", + "print(\n", + " f\"Visualized with isosurface_quantile={isosurface_quantile} (play around to see different iso surfaces).\"\n", + ")\n", + "img = pl.show(screenshot=True, window_size=(800, 400))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mldft", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorial/tutorial_4_hydra_omegaconf.ipynb b/notebooks/tutorial/tutorial_4_hydra_omegaconf.ipynb new file mode 100644 index 0000000..d0c7673 --- /dev/null +++ b/notebooks/tutorial/tutorial_4_hydra_omegaconf.ipynb @@ -0,0 +1,361 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9f1be328", + "metadata": {}, + "source": [ + "# Tutorial 4: Hydra, OmegaConf, Overrides\n", + "\n", + "In this tutorial, you will learn more about the underlying structure which includes config files, the hydra structure and the OmegaConf magic. We will also cover the topic of overrides which allows us to access the none-default settings." + ] + }, + { + "cell_type": "markdown", + "id": "98a0e451", + "metadata": {}, + "source": [ + "## 0 Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbe5acbd", + "metadata": {}, + "outputs": [], + "source": [ + "# import necessary packages\n", + "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from hydra import compose, initialize\n", + "from hydra.utils import instantiate\n", + "\n", + "# this makes sure that code changes are reflected without restarting the notebook\n", + "# this can be helpful if you want to play around with the code in the repo\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# omegaconf is used for configuration management\n", + "# omegaconf custom resolvers are small functions used in the config files. For example, \"get_len\" is used to get lengths of lists.\n", + "from mldft.utils import omegaconf_resolvers # this registers omegaconf custom resolvers\n", + "\n", + "# download a small dataset from huggingface that contains QM9 and QMugs data (possibly already downloaded)\n", + "# and change the DFT_DATA environment variable to the directory where the data is stored\n", + "\n", + "# https://huggingface.co/docs/datasets/cache#cache-directory\n", + "# The default cache directory is `~/.cache/huggingface/datasets`\n", + "# You can change it by setting this variable to any path you like\n", + "CACHE_DIR = None # e.g. change it to \"./hf_cache\"\n", + "\n", + "# clone the full repo\n", + "# https://huggingface.co/sciai-lab/structures25/tree/main\n", + "os.environ[\n", + " \"HF_HUB_DISABLE_PROGRESS_BARS\"\n", + "] = \"1\" # to avoid problems with the progress bar in some environments\n", + "from huggingface_hub import snapshot_download\n", + "\n", + "data_path = snapshot_download(\n", + " repo_id=\"sciai-lab/minimal_data_QM9_QMugs\", cache_dir=CACHE_DIR, repo_type=\"dataset\"\n", + ")\n", + "\n", + "dft_data = os.environ.get(\"DFT_DATA\", None)\n", + "os.environ[\"DFT_DATA\"] = data_path\n", + "print(\n", + " f\"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b03e2fca", + "metadata": {}, + "source": [ + "# 1 Hydra\n", + "\n", + "Hydra is used for configuration management. It can be thought of as a tree of configuration files. \n", + "Usually a parent config file is used to set global variables and to specify which other child config files to use. \n", + "The child config files then set specific variables that are usually related to a specific topic (e.g. model architecture, training parameters, data parameters, etc.).\n", + "\n", + "To understand how the child config files are impemented, it is recommended to take a look a the OmegaConf magic in the next chapter." + ] + }, + { + "cell_type": "markdown", + "id": "f9e12598", + "metadata": {}, + "source": [ + "# 2 OmegaConf Magic\n", + "\n", + "First, we look at and example config and see how the tree structure looks like. \n", + "\n", + "Taking a closer look at the config file, we can see that a OmegaConf resolver is used to get the length of the list of hidden layers.\n", + "In an additional example in the Apendix 1, we show how you can create your own custom resolvers.\n", + "\n", + "\n", + "We will use this config later to instantiate a model and a dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4fd83bd1", + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf\n", + "\n", + "# understanding the omegaconf config magic and syntax:\n", + "example_config = {\n", + " \"sub_dict\": {\"a\": 1, \"b\": 2, \"l\": [1, 2, 3]},\n", + " \"len_of_l\": \"${get_len:${sub_dict.l}}\", # this uses the custom resolver \"get_len\" to get the length of list l\n", + " # use this structure to cross-reference within a config ${sub_dict.l}\n", + " \"mlp\": {\n", + " \"_target_\": \"mldft.ml.models.components.mlp.MLP\", # this is used by hydra to instantiate an object of the given class\n", + " \"in_channels\": 3,\n", + " \"hidden_channels\": [16, 16, 1],\n", + " },\n", + "}\n", + "\n", + "omegaconf_example_config = OmegaConf.create(example_config)\n", + "# if you print the config naively, it shows just the strings\n", + "print(\"OmegaConf config:\", omegaconf_example_config)\n", + "# BUT if you access the value, it resolves the string using the custom resolver\n", + "print(\"Value from accessing len_of_l\", omegaconf_example_config.len_of_l) # prints 3\n", + "\n", + "# instantiate the MLP based on the example config above:\n", + "# when calling instantiate, the _target_ field is used to find the class\n", + "# and all other fields are passed as arguments to the class constructor:\n", + "# in this case the MLP class from mldft.ml.models.components.mlp\n", + "# with in_channels=3 and hidden_channels=[16, 16, 1] as arguments\n", + "mlp = instantiate(omegaconf_example_config.mlp)\n", + "print(\"\\nInstantiated MLP:\", mlp, \"\\n\")" + ] + }, + { + "cell_type": "markdown", + "id": "2ba731ac", + "metadata": {}, + "source": [ + "In yaml formatting the above config will look something like:\n", + "\n", + "```\n", + "# example_config.yaml\n", + "\n", + "sub_dict:\n", + " a: 1\n", + " b: 2\n", + " l: [1, 2, 3]\n", + "\n", + "# uses the custom resolver \"get_len\" to get the length of list l\n", + "len_of_l: ${get_len:${sub_dict.l}}\n", + "\n", + "# use this structure to cross-reference within a config: ${sub_dict.l}\n", + "mlp:\n", + " # used by Hydra to instantiate an object of the given class\n", + " _target_: mldft.ml.models.components.mlp.MLP\n", + " in_channels: 3\n", + " hidden_channels: [16, 16, 1]\n", + "```\n", + "\n", + "All our configs (in hierarchical structure are collected in the `configs` folder). The highest level config for model training to start from is the [configs/ml/train.yaml](../../configs/ml/train.yaml)." + ] + }, + { + "cell_type": "markdown", + "id": "e928e93a", + "metadata": {}, + "source": [ + "## 3 Config for model training\n", + "\n", + "Now, we want to load the config for the actual model training. Additionally, we load the data and create batches which can easierly be handled by the model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30a6c66c", + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf.dictconfig import DictConfig\n", + "\n", + "# load the config as Omegaconf Dict config for training a model\n", + "# with the defaut settings for data, optimizer, transforms, basis set, etc.\n", + "# this already handles the communication and combination of the different config files, e.g. for data and the model\n", + "with initialize(version_base=None, config_path=\"../../configs/ml\"):\n", + " config = compose(\n", + " config_name=\"train.yaml\",\n", + " overrides=[\n", + " \"data.dataset_name=QM9_perturbed_fock\", # this will no longer be necessary once the \"fixed\" is removed from the dataset_name\n", + " ],\n", + " )\n", + "\n", + "# remove the hydra specific stuff that only works in @hydra.main decorated functions\n", + "config.paths.output_dir = \"example_path\"\n", + "\n", + "datamodule = instantiate(config.data.datamodule)\n", + "datamodule.setup(stage=\"fit\")\n", + "datamodule.batch_size = 4 # set batch size to 4 (relatively small) for demonstration purposes\n", + "train_loader = datamodule.train_dataloader()" + ] + }, + { + "cell_type": "markdown", + "id": "9ec87358", + "metadata": {}, + "source": [ + "## 4 Overriding the default config\n", + "\n", + " As we not always want to use the default config, here is an examples of how to override settings. \n", + "In more detail, we now wish to use the QMugs dataset instead of the default QM9 dataset.\n", + "\n", + "Below, we also prepare the dataset for the fit stage of training, i.e. we use a smaller subset of the data for training and validation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b1c9ddc", + "metadata": {}, + "outputs": [], + "source": [ + "with initialize(version_base=None, config_path=\"../../configs/ml\"):\n", + " config_qmugs = compose(\n", + " config_name=\"train.yaml\",\n", + " overrides=[\n", + " # this overrides the data used to the qmugs dataset\n", + " \"data.dataset_name=QMUGS_perturbed_fock\", # with the dot we override a nested field\n", + " \"data/transforms=no_basis_transforms\", # with the / we override a whole file\n", + " ],\n", + " )\n", + "\n", + "# remove the hydra specific stuff that only works in @hydra.main decorated functions\n", + "config_qmugs.paths.output_dir = \"example_path\"\n", + "\n", + "datamodule_qmugs = instantiate(config_qmugs.data.datamodule)\n", + "datamodule_qmugs.setup(stage=\"fit\") # prepare the datasets\n", + "print(f\"Length of qmugs train set: {len(datamodule_qmugs.train_set)}\")\n", + "print(f\"Length of qmugs val set: {len(datamodule_qmugs.val_set)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "7c593fa7", + "metadata": {}, + "source": [ + "To get a better intuition on how the QMugs datset is different from QM9 in terms of complexity, we visualize a QMugs molecule below. For visualizations of example QM9 molecules, please have a look at [Tutorial 2](tutorial_2_visualization.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3df3daf4", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "# keep only the program name so downstream parsers don't see Jupyter's -f=...\n", + "sys.argv = sys.argv[:1]\n", + "\n", + "import pyvista\n", + "\n", + "from mldft.utils.molecules import build_molecule_ofdata\n", + "from mldft.utils.visualize_3d import get_sticks_mesh_dict\n", + "\n", + "basis_info_qmugs = instantiate(config_qmugs.data.basis_info)\n", + "\n", + "# look at a qmugs molecule to see that they are larger than qm9 molecules\n", + "for sample_qmugs in datamodule_qmugs.train_set:\n", + " if sample_qmugs.mol_id.startswith(\"qmugs\"):\n", + " print(\"Found a qmugs molecule:\", sample_qmugs.mol_id)\n", + " break\n", + "\n", + "mol_qmugs = build_molecule_ofdata(sample_qmugs, basis=basis_info_qmugs.basis_dict)\n", + "\n", + "\n", + "# this give a ball and stick model of the molecule\n", + "molecule_mesh = get_sticks_mesh_dict(mol_qmugs)\n", + "molecule_mesh[\"opacity\"] = 1\n", + "\n", + "# plot the molecule and the global frame using pyvista:\n", + "pyvista.set_jupyter_backend(\"html\")\n", + "pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)\n", + "pl.camera_position = \"zx\"\n", + "pl.enable_parallel_projection()\n", + "pl.add_mesh(**molecule_mesh)\n", + "pl.enable_shadows()\n", + "pl.reset_camera(\n", + " bounds=0.9\n", + " * np.stack([mol_qmugs.atom_coords().min(0), mol_qmugs.atom_coords().max(0)], axis=1).flatten()\n", + ")\n", + "\n", + "img = pl.show(screenshot=True, window_size=(800, 400))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "d5e4deb2", + "metadata": {}, + "source": [ + "## Appendix 1: Self-created custom resolver\n", + "\n", + "Above, get_len is already defined and registered as custom resolver in mldft.utils.omegaconf_resolvers.\n", + "But, we can also define our own omega conf custom resolver:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "016019c1", + "metadata": {}, + "outputs": [], + "source": [ + "# check if \"sum is already registered\"\n", + "if OmegaConf.has_resolver(\"sum\"):\n", + " print(\"sum is already registered\")\n", + "else:\n", + " print(\"registering sum\")\n", + " # register a custom resolver \"sum\" that sums up all its arguments\n", + " OmegaConf.register_new_resolver(\"sum\", lambda *args: sum(args))\n", + "\n", + "example_config = {\n", + " \"sub_dict\": {\"a\": 17, \"b\": -5},\n", + " \"sum_a_b\": \"${sum:${sub_dict.a}, ${sub_dict.b}}\", # this uses the custom resolver \"get_len\" to get the length of list l\n", + " # use this structure to cross-reference within a config ${sub_dict.l}\n", + "}\n", + "\n", + "omegaconf_example_config = OmegaConf.create(example_config)\n", + "# if you print the config naively, it shows just the strings\n", + "print(\"OmegaConf config:\", omegaconf_example_config)\n", + "# BUT if you access the value, it resolves the string using the custom resolver\n", + "print(\"Value from accessing len_of_l\", omegaconf_example_config.sum_a_b) # prints 3" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mldft", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorial/tutorial_5_mldftlitmodel.ipynb b/notebooks/tutorial/tutorial_5_mldftlitmodel.ipynb new file mode 100644 index 0000000..15272da --- /dev/null +++ b/notebooks/tutorial/tutorial_5_mldftlitmodel.ipynb @@ -0,0 +1,614 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9f1be328", + "metadata": {}, + "source": [ + "# Tutorial 5: MLDFT lit module\n", + "\n", + "In this Section, we will dive into the model structure and explain how pytorch lightning is used." + ] + }, + { + "cell_type": "markdown", + "id": "23647f88", + "metadata": {}, + "source": [ + "## 0 Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbe5acbd", + "metadata": {}, + "outputs": [], + "source": [ + "# import necessary packages\n", + "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import rich\n", + "import torch\n", + "from hydra import compose, initialize\n", + "from hydra.utils import instantiate\n", + "\n", + "from mldft.utils.log_utils.config_in_tensorboard import dict_to_tree\n", + "\n", + "# this makes sure that code changes are reflected without restarting the notebook\n", + "# this can be helpful if you want to play around with the code in the repo\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# omegaconf is used for configuration management\n", + "# omegaconf custom resolvers are small functions used in the config files like \"get_len\" to get lengths of lists\n", + "from mldft.utils import omegaconf_resolvers # this registers omegaconf custom resolvers\n", + "\n", + "# download a small dataset from huggingface that contains QM9 and QMugs data (possibly already downloaded)\n", + "# and change the DFT_DATA environment variable to the directory where the data is stored\n", + "\n", + "# https://huggingface.co/docs/datasets/cache#cache-directory\n", + "# The default cache directory is `~/.cache/huggingface/datasets`\n", + "# You can change it by setting this variable to any path you like\n", + "CACHE_DIR = None # e.g. change it to \"./hf_cache\"\n", + "\n", + "# clone the full repo\n", + "# https://huggingface.co/sciai-lab/structures25/tree/main\n", + "os.environ[\n", + " \"HF_HUB_DISABLE_PROGRESS_BARS\"\n", + "] = \"1\" # to avoid problems with the progress bar in some environments\n", + "from huggingface_hub import snapshot_download\n", + "\n", + "data_path = snapshot_download(\n", + " repo_id=\"sciai-lab/minimal_data_QM9_QMugs\", cache_dir=CACHE_DIR, repo_type=\"dataset\"\n", + ")\n", + "\n", + "dft_data = os.environ.get(\"DFT_DATA\", None)\n", + "os.environ[\"DFT_DATA\"] = data_path\n", + "print(\n", + " f\"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "60127dec", + "metadata": {}, + "source": [ + "## 1 Config and data loading\n", + "\n", + "As a first step, we have to load the \"train.yaml\" config as a OmegaConf Dict config. For now, we don't use any overwrites, but just use the default setting for data, optimizer, transforms, basis set, etc. As the [hydra \"tree\" structure](.notebooks/tutorial_4_hydra_omegaconf.ipynb) is used, this already handles the communication and combination of the different config files, e.g. for data and the model.\n", + "\n", + "After the data is loaded, we focus for demonstration purposes on one individual sample molecule." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30a6c66c", + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf.dictconfig import DictConfig\n", + "\n", + "from mldft.utils.molecules import build_molecule_ofdata\n", + "\n", + "# the following initialize already handles the communication and combination\n", + "# of the different config files, e.g. for data and the model\n", + "with initialize(version_base=None, config_path=\"../../configs/ml\"):\n", + " config = compose(\n", + " config_name=\"train.yaml\",\n", + " overrides=[\n", + " # we need one simple override here but otherwise we just use the default setting (see tutorial 4 for more information)\n", + " \"data.dataset_name=QM9_perturbed_fock\", # this will no longer be necessary once the \"fixed\" is removed from the dataset_name\n", + " # Add trainer overrides for demonstration purposes\n", + " \"trainer.max_epochs=1\",\n", + " \"+trainer.limit_train_batches=1\",\n", + " \"+trainer.limit_val_batches=1\",\n", + " \"+trainer.enable_checkpointing=False\",\n", + " \"data.datamodule.num_workers=0\",\n", + " ],\n", + " )\n", + "\n", + "# remove the hydra specific stuff that only works in @hydra.main decorated functions\n", + "config.paths.output_dir = \"example_path\"\n", + "\n", + "datamodule = instantiate(config.data.datamodule)\n", + "datamodule.setup(stage=\"fit\")\n", + "datamodule.batch_size = 4 # set batch size to 4 (relatively small) for demonstration purposes\n", + "train_loader = datamodule.train_dataloader()\n", + "\n", + "sample = datamodule.train_set[0]\n", + "\n", + "# need basis info to build a pySCF molecule object\n", + "# see below for more details on basis_info\n", + "basis_info = instantiate(config.data.basis_info)\n", + "\n", + "# build a pySCF molecule object from the OFData sample\n", + "mol = build_molecule_ofdata(sample, basis=basis_info.basis_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "9e57a243", + "metadata": {}, + "source": [ + "Next, we want to take a look at the machine learning model used to predict\n", + "the kinetic energy (and possibly other energies) from a given electron density.\n", + "The main module which handles the training is the MLDFTLitModule.\n", + "\n", + "For this, let's take a look at the part of the config that is used to configure the model.\n", + "It is a very long and nested config, which specifies everything needed for training.\n", + "\n", + "You will find in there amongst other things:\n", + "* The optimizer used to update the model parameters during training.\n", + "* The learning rate scheduler used to adjust the learning rate after every epoch during training.\n", + "* The loss function used to compute the training loss: It is used for backpropagation\n", + "to compute the gradients of the model parameters which will be applied to update each parameter.\n", + "* The net which is the main neural network architecture that takes the batched sample as input\n", + "and outputs a prediction for the energy.\n", + "* The basis_info which specifies the basis set used to represent the density.\n", + "* The dataset_statistics used to standardize the input densities and the output energy labels to\n", + "improve and stabilize training.\n", + "* The density_optimizer and denop_settings which specify how density optimization is performed with a trained model.\n", + "\n", + "Question: Can you find the optimizer and the learning rate that we use for training?\n", + "Question: Can you also find the optimizer and learning rate that we use during density optimization (denop)?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9465fb22", + "metadata": {}, + "outputs": [], + "source": [ + "import functools\n", + "\n", + "from mldft.ml.models.mldft_module import MLDFTLitModule\n", + "\n", + "rich.print(dict_to_tree(config.model, guide_style=\"dim\"))\n", + "\n", + "# The getattribute function prints a message whenever a hook method is called\n", + "# Therfore, we can later see in the output which hooks are called during training\n", + "# (e.g., on_train_start, training_step, etc.)\n", + "# find more information in the output after the trainer is called\n", + "\n", + "\n", + "def getattribute(self, name):\n", + " attr = object.__getattribute__(self, name)\n", + " hook_prefixes = (\"on_\", \"training_\", \"validation_\", \"test_\", \"predict_\")\n", + " if callable(attr) and any(name.startswith(p) for p in hook_prefixes):\n", + "\n", + " @functools.wraps(attr)\n", + " def wrapper(*args, **kwargs):\n", + " print(f\"Our lightning module is now calling: {name}\")\n", + " return attr(*args, **kwargs)\n", + "\n", + " return wrapper\n", + " return attr\n", + "\n", + "\n", + "mldft_module = instantiate(config.model)\n", + "mldft_module.__class__.__getattribute__ = getattribute\n", + "# the MLDFTLitModule inherits from pl.LightningModule\n", + "# which is a PyTorch Lightning specific class that handles the training loop\n", + "print(\"Successfully instantiated model:\", type(mldft_module))" + ] + }, + { + "cell_type": "markdown", + "id": "4ea3a691", + "metadata": {}, + "source": [ + "# 2 Forward pass through model\n", + "\n", + "Now, let's do a forward pass through the model with one batch of data. \n", + "\n", + "The forward output consists of three parts:\n", + "* First, the predicted energy for the given input electron density (in our case kinetic energy + XC energy).\n", + "* Second, the predicted gradients of the energy with respect to the input density coefficients.These are computed via automatic differentiation (autodiff) in PyTorch (see example below)\n", + "* Third, a direct prediction of the ground state density coefficients (or rather the difference between the input density coeffs and the ground state density coeffs).\n", + "The latter, we usually don't use during training, see coefficient_loss has weight 0.0 in the config above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "017b4b3a", + "metadata": {}, + "outputs": [], + "source": [ + "# let's do a forward pass through the model with one batch of data\n", + "# this is taking some time because the model was not moved to the GPU:\n", + "print(\"mldft_module.device:\", mldft_module.device, \"\\n\")\n", + "batch = next(iter(train_loader))\n", + "forward_out = mldft_module.forward(batch) # which does the same as mldft_module(batch)\n", + "\n", + "print(\"Model output:\", forward_out)" + ] + }, + { + "cell_type": "markdown", + "id": "63a2365a", + "metadata": {}, + "source": [ + "## 3 Training step\n", + "\n", + "This was a single forward pass through the model, but does not yet look much like training\n", + "instead we can make a training_step with the model. \n", + "\n", + "In more detail, during each training step, the following happens:\n", + "1. The training loop calls the `training_step` method of the `MLDFTLitModule`.\n", + "2. Inside `training_step`, the model processes the input batch to produce predictions.\n", + "3. The loss function computes the loss by comparing the predictions to the true labels.\n", + "4. Additional training metrics are computed and logged.\n", + "5. The compuatational graph is saved for a backward pass.\n", + "\n", + "Afterwards, the optimizer uses the loss to perform backpropagation and update the model weights.\n", + "\n", + "For more information on the optimizer, the Appendix 3 in this notebook can be recommended.\n", + "\n", + "To do the training step, we will need a trainer attached to the model.\n", + "(By the way, the model which we have just loaded is untrained, so the loss will be very large.)\n", + "\n", + "**Command for classical training:**\n", + "\n", + "Usually you would start a training with a command similar to this one: \n", + "```CUDA_VISIBLE_DEVICES=2 python mldft/ml/train.py experiment=str25/qm9_tf```\n", + "\n", + "Quick note: With```CUDA_VISIBLE_DEVICES```, you select which GPU to run the job on. Please, check after accessing the server which GPU is currently free with the following command: ```gpustat```.\n", + "\n", + "With the rest of the command you call the main training script with experiment specific config options." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c5f2e5d", + "metadata": {}, + "outputs": [], + "source": [ + "# instantiate the trainer\n", + "# for that, remove the hydra specific stuff that only works in @hydra.main decorated functions\n", + "config.paths.output_dir = \"example_path\"\n", + "config.paths.work_dir = \"example_path\"\n", + "trainer = instantiate(config.trainer)\n", + "print(\"Successfully instantiated trainer:\", type(trainer))\n", + "mldft_module.trainer = trainer # add the trainer to the module\n", + "\n", + "# also, let us disable the logging for this tutorial:\n", + "mldft_module.log = lambda *args, **kwargs: None\n", + "mldft_module.log_dict = lambda *args, **kwargs: None\n", + "\n", + "train_step_out = mldft_module.training_step(batch)\n", + "print(\"Output of training step:\", train_step_out.keys())" + ] + }, + { + "cell_type": "markdown", + "id": "4eea2ed3", + "metadata": {}, + "source": [ + "The training step returns a dictionary containing the following things:\n", + "* 'loss': the total loss computed for the batch, which is used for backpropagation\n", + "* 'model_outputs': containing the three outputs of the forward pass ('pred_energy', 'pred_gradients', 'pred_diff')\n", + "* 'projected_gradient_difference': the difference between the predicted and true energy gradients projected (to preserve the number of electrons)" + ] + }, + { + "cell_type": "markdown", + "id": "a39622b2", + "metadata": {}, + "source": [ + "## 4 Running a full training epoch\n", + "\n", + "The LitModule uses a specific \"syntax\" to handle these training details under the hood.\n", + "For instance, the backwards on the loss and also the optimizer step are performed automatically.\n", + "The lightning model uses for that by default the \"loss\" value\n", + "returned in the output dictionary of the training step.\n", + "\n", + "Furthermore, in the lightning module we do not see an explicit training loop\n", + "that loops over the batches in the train_loader.\n", + "This is automatically handled by the pytorch lightning trainer that combines the\n", + "model with the dataloader(s), e.g. in\n", + "trainer.fit(model, train_dataloader, val_dataloader)\n", + "\n", + "In addition, the logging to the tensorboard is handled simply via self.log in the lightning module\n", + "and uses the logger that is attached to the trainer (via the trainer that is attached to the model).\n", + "\n", + "Additionally, the lightning module handles the validation loop automatically.\n", + "It works via the validation_step method similar to the training loop based on the training_step method.\n", + "There are even quite a bit more methods that follow a standard syntax and can be used to\n", + "achieve certain behavior during training, e.g. on_epoch_start, on_epoch_end, etc.\n", + "see https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html" + ] + }, + { + "cell_type": "markdown", + "id": "5cc4fa49", + "metadata": {}, + "source": [ + "Additional information to the lightning module:\n", + "\n", + "1. The **Lightning module** holds our machine learning model at its core and defines what to train and how to train. There are different subclasses implemented in the code:\n", + "* The **model architetcture** is defined under \"mldft_module.net\". Its forward pass is called in the ```forward``` method of the mldft_module. \n", + "+ *Analogy:* You can think of ```__init__``` as the ingredient list of a recipe and ```forward``` as a simple instruction. However, the magic happens in between...\n", + "* The **training logic** lies in the ```training step```.\n", + "* The **validation and test logic** is in the ```validation_step``` and ```test_step```.\n", + "* The **optimizer(s)** are set in ```configure_optimizers```.\n", + "+ *Analogy:* To continue with the analogy, you can think of a the additional functions in the lightning module class as your way to optimize the recipe. With each training step you learn new things and adjuste the recipe (parameters). To make sure your changes are actually good, you also continuously validate it. The optimizatation process happens in this cooking example in your brain as you consider adding more salt etc. In the code, the opimization process happens in the optimizer function.\n", + "\n", + "2. The **LightningDataModule** organizes all the data-related logic:\n", + "* With the ```setup``` function in the DataModule class one defines how to load the data.\n", + "* Also DataLoaders have to be created with in the ```train_dataloader``` and the ```val_dataloader``` etc.\n", + "* Optionally, one can define a preprocessing of the data.\n", + "\n", + "+ Note: The LightningDataModule helps keeping data handling clean and separate from the model logic.\n", + "\n", + "3. Think of the **Trainer** as an orchestrator, which handels:\n", + "* Training loops\n", + " * Note: You don't manualy write the training loops in Lightning - the ```Trainer```automates them.\n", + "* Validation & testing\n", + "* Logging\n", + "* Checkpointing\n", + "* Device placement (CPU, GPU)\n", + "* Distributing training \n", + "\n", + "Important note: In our file structure, you can find a [\"train.py\"](../../mldft/ml/train.py) file which is the main entry point for training. It instantiates all relevant components, i.e. Trainer, datamodule, lightning module, etc. \n", + "\n", + "In a subfolder data, there is a [\"datamodule.py\"](../../mldft/ml/data/datamodule.py) file associated with the DataModule and lastly in the folder models a [\"mldft_module.py\"](../../mldft/ml/models/mldft_module.py) file which handels the core of the model. \n", + "[Config files](../../configs/ml/train.yaml) (as discussed in [Tutorial 4](./tutorial_4_hydra_omegaconf.ipynb)) are the place where most the variables are stored for the training." + ] + }, + { + "cell_type": "markdown", + "id": "17c5303a", + "metadata": {}, + "source": [ + "#### Now it all comes together in our very first training epoch\n", + "The lightning module follows a specific syntax of methods which will be executed in a very specific order during training. For instance, the `on_train_epoch_start` method will be executed every time when a training epoch is started (as one might have guessed). Similarly, there is the`on_before_backward` method which is called shortly before the backward or the `on_validation_batch_end` that is called after the processing of each validation batch. \n", + "\n", + "Below, we will execute a \"full training\" (but only for one epoch) and log all these methods in the order in which they are executed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bd9ac67", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "\n", + "# freshly instantiate the trainer for a clean state\n", + "trainer = instantiate(config.trainer, enable_progress_bar=False)\n", + "\n", + "# disable all user warnings for the following trainer.fit call\n", + "with warnings.catch_warnings():\n", + " warnings.simplefilter(\"ignore\", UserWarning)\n", + "\n", + " trainer.fit(mldft_module, datamodule=datamodule);" + ] + }, + { + "cell_type": "markdown", + "id": "8f57bcfa", + "metadata": {}, + "source": [ + "# Appendix 1: Automatic differentiaton\n", + "\n", + "We will have a short intermezzo on understanding how automatic differentiation (via backpropagation) works in PyTorch.\n", + "When you have a tensor with requires_grad=True, all operations on that tensor are tracked\n", + "and a computation graph is built in the background.\n", + "Then when you call backward() on a tensor, the gradients of that tensor with respect to\n", + "all tensors that have requires_grad=True and were used to compute that tensor\n", + "are computed via backpropagation through the computation graph.\n", + "Here is a simple example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96e6eda8", + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.tensor(2.0, requires_grad=True)\n", + "y = x**2 + 3 * x + 1\n", + "print(\"y:\", y, \"\\n\")\n", + "y.backward() # this computes the gradient of y with respect to x via backpropagation\n", + "print(\"dy/dx:\", x.grad, \"\\n\") # dy/dx = 2*x + 3 = 2*2 + 3 = 7\n", + "\n", + "# small subtlety: if you do multiple operations on a tensor\n", + "# the gradients are accumulated in the .grad attribute\n", + "x = torch.tensor(2.0, requires_grad=True)\n", + "y1 = x**2\n", + "y2 = x**3\n", + "y1.backward() # this computes the gradient of y1 with respect to x via backpropagation\n", + "print(\"dy1/dx:\", x.grad) # dy1/dx = 2*x = 2*2 = 4\n", + "y2.backward() # this computes the gradient of y2 with respect to x via backpropagation\n", + "print(\"dy1/dx + dy2/dx:\", x.grad, \"\\n\") # dy1/dx + dy2/dx = 2*x + 3*x**2 = 2*2 + 3*2**2 = 16\n", + "\n", + "# to zero the gradients, you can use the zero_() method\n", + "x.grad.zero_()\n", + "print(\"zeroed gradients:\", x.grad, \"\\n\")\n", + "\n", + "# detach can be used to stop tracking operations on a tensor\n", + "x = torch.tensor(2.0, requires_grad=True)\n", + "y = x**2\n", + "z = y.detach() + 3 * x # detach stops tracking operations on y\n", + "print(\"z:\", z)\n", + "z.backward() # this computes the gradient of z with respect to x via backpropagation\n", + "print(\"dz/dx:\", x.grad, \"\\n\") # dz/dx = 3, since y was detached\n", + "\n", + "# by default, after one calls backward(), the computation graph is deleted to save memory\n", + "# if you want to call backward() multiple times on the same graph, for instance to compute a second derivative\n", + "# (as we actually do in our project when we first compute the energy gradient w.r.t. the density\n", + "# and then use that energy gradient to compute a loss function that is then used\n", + "# for another backward call to update the model parameters)\n", + "# in this case you need to specify retain_graph=True\n", + "x = torch.tensor(2.0, requires_grad=True)\n", + "y = x**6\n", + "dy_dx = torch.autograd.grad(y, x, create_graph=True, retain_graph=True)[\n", + " 0\n", + "] # this computes dy/dx = 6*x**5\n", + "print(\"dy/dx:\", dy_dx)\n", + "d2y_dx2 = torch.autograd.grad(dy_dx, x)[0] # this computes d2y/dx2 = 30*x**4 = 30*2**4 = 480\n", + "print(\"d2y/dx2:\", d2y_dx2, \"\\n\")" + ] + }, + { + "cell_type": "markdown", + "id": "85af7337", + "metadata": {}, + "source": [ + "## Appendix 2: Partial\n", + "\n", + "Since the optimizer in the config is only partially (\"_partial_\") initialized, we want to take a look at what this actually means in the example below.\n", + "\n", + "As you might know the standard normal distribution is a special case of a classical Gaussian distribution. To include this knowledge but simplify futher calling, we could use the partial function and with it specify the necessary mean and standard deviation properties that make a Gaussian a standard normal distribution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70d5b3ea", + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "\n", + "def gaussian(x, mean, std):\n", + " return torch.exp(-0.5 * ((x - mean) / std) ** 2) / (std * (2 * torch.pi) ** 0.5)\n", + "\n", + "\n", + "standard_normal = partial(gaussian, mean=0.0, std=1.0)\n", + "print(standard_normal)\n", + "# standard_normal is now a function that only takes x as argument\n", + "# and mean and std are fixed to 0.0 and 1.0 respectively\n", + "x = torch.linspace(-5, 5, steps=100)\n", + "y = standard_normal(x=x)\n", + "plt.plot(x.numpy(), y.numpy())\n", + "plt.title(\"Standard normal distribution\")\n", + "plt.xlabel(\"x\")\n", + "plt.ylabel(\"Probability density\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "895a4ad4", + "metadata": {}, + "source": [ + "## Appendix 3: Optimizer\n", + "\n", + "Now, we want to examplify how the updating of the model parameters works during training\n", + "for that we need to attach an optimizer to the model. \n", + "\n", + "Next, if you look carefully you will find that the optimizer in the config is only partially (\"_partial_\") initialized. This means that some of the arguments are missing and will be filled in later (more info see Appendix 1). In particular the model parameters that should be optimized are missing, because the model parameters are not known before the model is instantiated.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9be9707e", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer_partially_initialized = instantiate(config.model.optimizer)\n", + "optimizer = optimizer_partially_initialized(params=mldft_module.parameters())\n", + "\n", + "# an alternative more compact option would be the following:\n", + "# optimizer = instantiate(config.model.optimizer, params=mldft_module.parameters())\n", + "\n", + "mldft_module.optimizer = optimizer # add the optimizer to the module\n", + "print(\"Successfully instantiated optimizer and linked it with model parameters:\", type(optimizer))\n", + "\n", + "for name, model_param in mldft_module.named_parameters():\n", + " print(name, model_param.shape)\n", + " break # just the first parameter\n", + "\n", + "# first we zero the gradients of the model parameters\n", + "mldft_module.optimizer.zero_grad()\n", + "# print the gradient of the first parameter (should be None after zeroing the grads):\n", + "print(\"Gradient of first parameter before backward:\", model_param.grad)\n", + "# then we call backward on the loss to compute the gradients of the model parameters\n", + "try:\n", + " train_step_out[\"loss\"].backward(\n", + " retain_graph=False\n", + " ) # so that this cell can in principle be run multiple times\n", + " # now the gradients of the model parameters are stored in the .grad attribute of each parameter\n", + " print(\"Gradient of first parameter after backward:\", model_param.grad, model_param.grad.shape)\n", + " old_model_param = (\n", + " model_param.clone().detach()\n", + " ) # clone and detach to keep a copy of the old parameters\n", + "\n", + " # now, we can update the model parameters with one step of the optimizer\n", + " mldft_module.optimizer.step()\n", + " print(\n", + " \"Maximum relative change in first parameter after one optimizer step:\",\n", + " ((model_param - old_model_param) / old_model_param).abs().max(),\n", + " )\n", + "except RuntimeError as e:\n", + " print(\"Caught expected RuntimeError due to multiple backward calls on the same graph.\")" + ] + }, + { + "cell_type": "markdown", + "id": "792d8b11", + "metadata": {}, + "source": [ + "# Appendix 4: Dataset statistics\n", + "\n", + "One small but not to be underestimated detail of our training are the dataset statistics.\n", + "These are used to standardize the input densities and the output energy labels\n", + "to improve and stabilize training.\n", + "As such, the dataset statistics are specific to which dataset (QM9 or QMUGS) and energy label is used (E_kin, E_xc, E_kin + E_xc, etc.),\n", + "as well as to which transforms are applied to the input densities (e.g. local_frames_global_symmetric_natrep).\n", + "\n", + "The dataset_statistics are essentially a .zarr folder, which can be seen in the config path. After instantiating it, we see for each relevant quantity some additonal statistical values, like the mean and std, as well as the abs_max value." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb328f02", + "metadata": {}, + "outputs": [], + "source": [ + "# let's take a look at the respective part in the config to verify that:\n", + "rich.print(dict_to_tree(config.data.dataset_statistics, guide_style=\"dim\"))\n", + "\n", + "from mldft.ml.preprocess.dataset_statistics import DatasetStatistics\n", + "\n", + "dataset_statistics = instantiate(config.data.dataset_statistics)\n", + "print(\"Successfully instantiated dataset_statistics:\", type(dataset_statistics))\n", + "dataset_statistics" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mldft", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorial/tutorial_6_density_optimization.ipynb b/notebooks/tutorial/tutorial_6_density_optimization.ipynb new file mode 100644 index 0000000..d59c7e3 --- /dev/null +++ b/notebooks/tutorial/tutorial_6_density_optimization.ipynb @@ -0,0 +1,660 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9f1be328", + "metadata": {}, + "source": [ + "# Tutorial 6: Density Optimization\n", + "\n", + "In this notebook, you will learn about the density optimization (denop) method.\n", + "In more detail, the density will converge towards the groundstate as enough iterations are performed." + ] + }, + { + "cell_type": "markdown", + "id": "cbffeece", + "metadata": {}, + "source": [ + "## 0 Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbe5acbd", + "metadata": {}, + "outputs": [], + "source": [ + "# import necessary packages\n", + "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import rich\n", + "import torch\n", + "from hydra import compose, initialize\n", + "from hydra.utils import instantiate\n", + "\n", + "# this makes sure that code changes are reflected without restarting the notebook\n", + "# this can be helpful if you want to play around with the code in the repo\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "from mldft.ml.models.mldft_module import MLDFTLitModule\n", + "\n", + "# omegaconf is used for configuration management\n", + "# omegaconf custom resolvers are small functions used in the config files like \"get_len\" to get lengths of lists\n", + "from mldft.utils import omegaconf_resolvers # this registers omegaconf custom resolvers\n", + "from mldft.utils.log_utils.config_in_tensorboard import dict_to_tree" + ] + }, + { + "cell_type": "markdown", + "id": "9fd682dd", + "metadata": {}, + "source": [ + "## 1 (Config) settings for denop\n", + "\n", + "The main denisty optimization block is applied after the model is trained. \n", + "In the command line, it could be execute by the following command: \n", + "\n", + "```CUDA_VISIBLE_DEVICES=6 python mldft/ofdft/run_density_optimization.py run_path=\"/export/scratch/ialgroup/dft_str25/models/train/runs/088__from_checkpoint_009__str25\\qm9_tf\" n_molecules=10 split=test```\n", + "\n", + "However, it can also be used during training to improve the model's performance. \n", + "For the following notebook, we will use a pretrained model, which can be accessefd by a checkpoint path." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a708b7b", + "metadata": {}, + "outputs": [], + "source": [ + "# download a small dataset from huggingface that contains QM9 and QMugs data\n", + "# and change the DFT_DATA environment variable to the directory where the data is stored\n", + "\n", + "# https://huggingface.co/docs/datasets/cache#cache-directory\n", + "# The default cache directory is `~/.cache/huggingface/datasets`\n", + "# You can change it by setting this variable to any path you like\n", + "CACHE_DIR = None # e.g. change it to \"./hf_cache\"\n", + "\n", + "REPO_ID = \"sciai-lab/minimal_data_QM9_QMugs\"\n", + "\n", + "print(\"Using tiny datasets\")\n", + "\n", + "# clone the full repo\n", + "# https://huggingface.co/sciai-lab/structures25/tree/main\n", + "\n", + "os.environ[\n", + " \"HF_HUB_DISABLE_PROGRESS_BARS\"\n", + "] = \"1\" # to avoid problems with the progress bar in some environments\n", + "from huggingface_hub import hf_hub_download, snapshot_download\n", + "\n", + "data_path = snapshot_download(\n", + " repo_id=\"sciai-lab/minimal_data_QM9_QMugs\", cache_dir=CACHE_DIR, repo_type=\"dataset\"\n", + ")\n", + "\n", + "dft_data = os.environ.get(\"DFT_DATA\", None)\n", + "os.environ[\"DFT_DATA\"] = data_path\n", + "print(\n", + " f\"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}.\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e46e188", + "metadata": {}, + "outputs": [], + "source": [ + "import contextlib\n", + "\n", + "# load the model from the checkpoint (downloaded from our huggingface model repo):\n", + "\n", + "# https://huggingface.co/docs/datasets/cache#cache-directory\n", + "# The default cache directory is `~/.cache/huggingface/datasets`\n", + "# You can change it by setting this variable to any path you like\n", + "CACHE_DIR = None # e.g. change it to \"./hf_cache\"\n", + "\n", + "# https://huggingface.co/sciai-lab/structures25/tree/main\n", + "print(\"Using QM9 model\")\n", + "qm9_model_path = hf_hub_download(\n", + " repo_id=\"sciai-lab/structures25\",\n", + " filename=\"trained-on-qm9/trained-on-qm9.ckpt\",\n", + " cache_dir=CACHE_DIR,\n", + ")\n", + "\n", + "\n", + "@contextlib.contextmanager\n", + "def _safe_map_location():\n", + " tls = torch._utils._thread_local_state\n", + " had_attr = hasattr(tls, \"map_location\")\n", + " if not had_attr:\n", + " setattr(tls, \"map_location\", None)\n", + " try:\n", + " yield\n", + " finally:\n", + " if hasattr(tls, \"map_location\") and not had_attr:\n", + " delattr(tls, \"map_location\")\n", + "\n", + "\n", + "def safe_load_from_ckpt(path):\n", + " for attempt in (1, 2): # retry once, because the first call can trip the bug\n", + " try:\n", + " with _safe_map_location():\n", + " return MLDFTLitModule.load_from_checkpoint(path, map_location=\"cpu\")\n", + " except AttributeError as e:\n", + " if \"map_location\" in str(e) and attempt == 1:\n", + " continue\n", + " raise\n", + "\n", + "\n", + "mldft_module_trained = safe_load_from_ckpt(qm9_model_path)\n", + "mldft_module_trained.eval() # set model to eval mode\n", + "\n", + "print(\"Successfully loaded trained model from checkpoint:\", type(mldft_module_trained))" + ] + }, + { + "cell_type": "markdown", + "id": "c4428484", + "metadata": {}, + "source": [ + "For this model, we have to use the \"local_frames_global_natrep_add_lframe\" transformation, which we can access by overwriting the default.\n", + "\n", + "But careful with the dataset statistics! The dataset_statistics used to create the SAD intitial guess is the starting point for denop. In the denop we, find two dataset statistics, one in the the model and one in the data section. They differ by the transformations applied. The initial SAD guess is built completely without any transforms applied and is only later transformed using sample.transformation_matrix shortly before starting the denop. To ensure no transforms are applied we use use the \"config_denop.model.dataset_statistics and a specifically called omegaconf custom resolver (\"to_no_basis_transforms_dataset_statistics\") handels the rest.\n", + "\n", + "With the \"datamodule.test_dataloader()\", we prepare a test set for density optimization. \n", + "\n", + "Additionally, we want to batch the data and for demonstartion purposes only use the first batch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7b1a0bc", + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import open_dict\n", + "\n", + "from mldft.ml.data.components.convert_transforms import PrepareForDensityOptimization\n", + "from mldft.ofdft.functional_factory import requires_grid\n", + "from mldft.utils.omegaconf_resolvers import to_no_basis_transforms_dataset_statistics\n", + "\n", + "# for this model we need to use local frames so we need slightly different transforms\n", + "with initialize(version_base=None, config_path=\"../../configs/ml\"):\n", + " config_denop = compose(\n", + " config_name=\"train.yaml\",\n", + " overrides=[\n", + " \"data/transforms=local_frames_global_natrep_add_lframes\",\n", + " \"data.dataset_name=QM9_perturbed_fock\", # this will no longer be necessary once the \"fixed\" is removed from the dataset_name\n", + " \"data.transforms.use_cached_data=False\", # to use untransformed data paths\n", + " ],\n", + " )\n", + " # IMPORTANT: for denop we need to instantiate the model.dataset_statistics, the basis_info and the datamodule\n", + " # use open_dict envrionment to modify the config since it is frozen by default\n", + " with open_dict(config_denop):\n", + " config_denop.model.dataset_statistics = config_denop.data.dataset_statistics.copy()\n", + " config_denop.model.dataset_statistics.path = to_no_basis_transforms_dataset_statistics(\n", + " dataset_statistics_path=config_denop.model.dataset_statistics.path,\n", + " transformation_name=config_denop.data.transforms.name,\n", + " )\n", + "\n", + "model_dataset_statistics_for_denop = instantiate(config_denop.model.dataset_statistics)\n", + "\n", + "# remove the hydra specific stuff that only works in @hydra.main decorated functions\n", + "config_denop.paths.output_dir = \"example_path\"\n", + "\n", + "datamodule = instantiate(config_denop.data.datamodule)\n", + "datamodule.setup(stage=\"test\")" + ] + }, + { + "cell_type": "markdown", + "id": "e4747581", + "metadata": {}, + "source": [ + "Next, some transformations to the right data type have to be done." + ] + }, + { + "cell_type": "markdown", + "id": "3135d31c", + "metadata": {}, + "source": [ + "Additionally, the denop settings and the density optimizer settings in the config need to be instantiated." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c654ed8", + "metadata": {}, + "outputs": [], + "source": [ + "# instantiate the ofdft config:\n", + "with initialize(version_base=None, config_path=\"../../configs/ofdft\"):\n", + " config_ofdft = compose(\n", + " config_name=\"ofdft.yaml\",\n", + " )\n", + "\n", + "\n", + "# print configs for denop:\n", + "print(\"\\nConfig for density_optimizer:\")\n", + "rich.print(dict_to_tree(config_ofdft.optimizer, guide_style=\"dim\"))\n", + "\n", + "# for denop our model needs the following additional things:\n", + "# denop_settings = instantiate(config_denop.model.denop_settings)\n", + "density_optimizer = instantiate(config_ofdft.optimizer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14991228", + "metadata": {}, + "outputs": [], + "source": [ + "from mldft.ml.data.components.dataset import OFDataset\n", + "\n", + "# set pytorch dtype default to float64 for better numerical accuracy in denop\n", + "from mldft.ml.data.components.of_data import Representation\n", + "\n", + "torch.set_default_dtype(torch.float64)\n", + "\n", + "# customise the transformations for density optimization:\n", + "basis_info = instantiate(config_denop.data.basis_info)\n", + "transforms = instantiate(config_denop.data.transforms)\n", + "add_grid = requires_grid(\n", + " config_denop.data.target_key, config_ofdft.negative_integrated_density_penalty_weight\n", + ")\n", + "transforms.pre_transforms.insert(0, PrepareForDensityOptimization(basis_info, add_grid=add_grid))\n", + "transforms.add_transformation_matrix = True\n", + "transforms.use_cached_data = False\n", + "\n", + "dataset_kwargs = instantiate(config_denop.data.datamodule.dataset_kwargs)\n", + "dataset_kwargs.update(\n", + " {\n", + " \"limit_scf_iterations\": -1,\n", + " \"additional_keys_at_ground_state\": {\n", + " \"of_labels/energies/e_electron\": Representation.SCALAR,\n", + " \"of_labels/energies/e_ext\": Representation.SCALAR,\n", + " \"of_labels/energies/e_hartree\": Representation.SCALAR,\n", + " \"of_labels/energies/e_kin\": Representation.SCALAR,\n", + " \"of_labels/energies/e_kin_plus_xc\": Representation.SCALAR,\n", + " \"of_labels/energies/e_kin_minus_apbe\": Representation.SCALAR,\n", + " \"of_labels/energies/e_kinapbe\": Representation.SCALAR,\n", + " \"of_labels/energies/e_xc\": Representation.SCALAR,\n", + " \"of_labels/energies/e_tot\": Representation.SCALAR,\n", + " },\n", + " }\n", + ")\n", + "\n", + "denop_dataset = OFDataset(\n", + " paths=datamodule.test_set.paths,\n", + " num_scf_iterations_per_path=None,\n", + " basis_info=basis_info,\n", + " transforms=transforms,\n", + " **dataset_kwargs,\n", + ")\n", + "\n", + "sample_double = denop_dataset[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ffcca24", + "metadata": {}, + "outputs": [], + "source": [ + "from mldft.ml.data.components.convert_transforms import ToTorch\n", + "\n", + "# by default for denop we use double precision\n", + "print(\"sample.pos.dtype before ToTorch:\", sample_double.pos.dtype)\n", + "to_float_32 = ToTorch(float_dtype=torch.float32)\n", + "sample_float = sample_double.clone()\n", + "sample_float = to_float_32(sample_float)\n", + "print(\"sample.pos.dtype after ToTorch:\", sample_float.pos.dtype)\n", + "\n", + "# after converting to float we can do a simple forward pass through the trained model\n", + "forward_out_trained = mldft_module_trained.forward(sample_float)" + ] + }, + { + "cell_type": "markdown", + "id": "2e81da78", + "metadata": {}, + "source": [ + "## 2 Functional Factory and SAD guess\n", + "\n", + "Now, the config settings are prepared and instantiated and we can take a look at the \"FunctionalFactory\" and the Sad guess, which will be calles during the actual denop procedure.\n", + "\n", + "The \"FunctionalFactory\" is used to create an energy functional from the trained model that can be used for density optimization.\n", + "\n", + "The contributions returned are therefore: ```contributions = [mldft_module_trained, \"hartree\", \"nuclear_attraction\"]```\n", + "* With our model, we predict T_s + E_xc, i.e. the non-interacting kinetic energy plus the exchange-correlation energy.\n", + "* The hartree energy is the classical electron-electron repulsion energy based on the current density.\n", + "* The nuclear attraction energy is the attraction of the electrons to the nuclei based on the current density.\n", + "\n", + "--\n", + "* Also note: Since the nuclear repulsion energy does not depend on the density, it is not part of the functional\n", + "but is computed later on directly from the mol object via nuclear_repulsion = mol.energy_nuc()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4dacc263", + "metadata": {}, + "outputs": [], + "source": [ + "# functional factory (which is a slighly more complicated thing)\n", + "from mldft.ofdft.functional_factory import FunctionalFactory\n", + "\n", + "func_factory = FunctionalFactory.from_module(\n", + " module=mldft_module_trained,\n", + " xc_functional=config_ofdft.xc_functional, # not used in our case since we predict T_s + E_xc\n", + " negative_integrated_density_penalty_weight=config_ofdft.negative_integrated_density_penalty_weight,\n", + " # the latter is zero by default (no penalty for regions with negative electron densities)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ce89a091", + "metadata": {}, + "source": [ + "Now, on to the SAD (Sum of Atomic Denisties):\n", + "\n", + "The SAD guess is a sum of independent atom-type specific densities that are based on dataset statistics\n", + "and for which the total number of electrons matches the total number of electrons in the molecule.\n", + "Even though there are also other first guess methods like MINAO or HÜCKEL (well established initial guesses already implemented in the `pyscf` package) or the option to learn the initial guess, we use the simple SAD guess as a default in STRUCTURES25 since it is cheapest of the ones listed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96346595", + "metadata": {}, + "outputs": [], + "source": [ + "from mldft.ofdft.callbacks import ConvergenceCallback\n", + "from mldft.ofdft.density_optimization import density_optimization_with_label\n", + "from mldft.utils.sad_guesser import SADNormalizationMode\n", + "\n", + "# since we do use SAD (Sum of Atomic Densities) as initial density guess, we have to specify\n", + "# the following keyword arguments that are passed to the SAD guesser,\n", + "# see SADGuesser class for details:\n", + "\n", + "sad_guess_kwargs = dict(\n", + " dataset_statistics=model_dataset_statistics_for_denop,\n", + " normalization_mode=SADNormalizationMode.PER_ATOM_WEIGHTED,\n", + " basis_info=basis_info,\n", + " weigher_key=\"ground_state_only\",\n", + " spherical_average=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7bb736c5", + "metadata": {}, + "source": [ + "## 3 Denop process and results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b4f7530", + "metadata": {}, + "outputs": [], + "source": [ + "# change max number of interations:\n", + "density_optimizer.max_cycle = 10 # You migth want to change the number of iterations to speed things up, but this migth not guarantee convergence\n", + "\n", + "metric_dict, callback, energies_label, energy_functional = density_optimization_with_label(\n", + " sample=sample_double, # OFData sample object containing required tensors for the functional.\n", + " mol=sample_double.mol, # Molecule object used for the initial guess and building the grid (used for eval of XC functional).\n", + " optimizer=density_optimizer, # Optimizer used for the density optimization process.\n", + " func_factory=func_factory, # see above\n", + " callback=ConvergenceCallback(), # specifies which iteration to report as the converged result\n", + " # (in our case of \"last_iter\" as convergence criterion, this is simply the last iteration.)\n", + " initial_guess_str=config_ofdft.initialization, # in our case SAD is used as initial guess (see above)\n", + " max_xc_memory=config_ofdft.ofdft_kwargs.max_xc_memory, # XC is computed on the grid this defines an upper limit for the grid size\n", + " # not relevant when using e_kin_plus_xc as training target\n", + " # best doc string explanation is the following:\n", + " # Guess of the maximum memory that should be taken by the aos in MB. Total usage might be higher.\n", + " # Defaults to the pyscf default of 4000MB\n", + " normalize_initial_guess=config_ofdft.ofdft_kwargs.normalize_initial_guess, # Whether to normalize the initial guess to the correct number of electrons.\n", + " proj_minao_module=None, # Lightning module used to improve the initial guess from SAD to some learned initial guess.\n", + " sad_guess_kwargs=sad_guess_kwargs, # see above\n", + " convergence_criterion=config_ofdft.convergence_criterion, # in our case \"last_iter\", i.e. we simply take the last iteration and stop iterating if the gradient norm is below the convergence_tolerance\n", + " disable_printing=False, # Whether to disable printing of the optimization progress.\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "1b7e5c84", + "metadata": {}, + "source": [ + "From the \"density_optimization_with_label\", the following properties are returned:\n", + "* metric_dict: Dictionary containing various metrics collected during the optimization process.\n", + "* callback: ConvergenceCallback object used to determine convergence.\n", + "* energies_label: Dictionary containing the energies computed during the optimization process, including the label energies.\n", + "* energy_functional: The energy functional used for the optimization.\n", + "\n", + "If you wnat to you can also look athe results individually:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a944853", + "metadata": {}, + "outputs": [], + "source": [ + "# contains metrics evaluating the final density and energy after denop (gs=ground_state)\n", + "metric_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be0fb1a3", + "metadata": {}, + "outputs": [], + "source": [ + "from mldft.ofdft.energies import Energies\n", + "\n", + "# contains the different energies at the \"converged\" state\n", + "energies_label.energies_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2745c43", + "metadata": {}, + "outputs": [], + "source": [ + "# The callback contains the states of all iterations\n", + "print(\"Length of callback.states:\", len(callback.energy))\n", + "# that converged result can be obtained via:\n", + "print(\"Converged result:\")\n", + "callback.get_convergence_result()" + ] + }, + { + "cell_type": "markdown", + "id": "94626789", + "metadata": {}, + "source": [ + "## 4 Visualizations of results\n", + "\n", + "As a last step, let's visualize the density optimization process and see how it converges toward the tolance set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ce88518", + "metadata": {}, + "outputs": [], + "source": [ + "# plot how the gradient norm evolved during denop:\n", + "plt.plot(callback.gradient_norm, label=\"Gradient norm\")\n", + "plt.yscale(\"log\")\n", + "plt.xlabel(\"Iteration\")\n", + "plt.ylabel(\"Gradient norm\")\n", + "plt.title(\"Density optimization convergence\")\n", + "plt.axhline(\n", + " density_optimizer.convergence_tolerance, color=\"red\", linestyle=\"--\", label=\"Tolerance\"\n", + ")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "729ad361", + "metadata": {}, + "source": [ + "Furthermore, we want to look at the Energy evolvance per iteration (remember we predict T_s + E_xc, and \"hartee\" and \"nuclear_attractio\"n energy are ajusted acording to the density or independent of it)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "708f0831", + "metadata": {}, + "outputs": [], + "source": [ + "from mldft.ofdft.ofstate import OFState\n", + "\n", + "# OFState returned by the ConvergenceCallback,\n", + "# amongst other things, it contains the predicted coefficients and energies\n", + "# at every iteration\n", + "energies_dict = {key: [] for key in energies_label.energies_dict.keys()}\n", + "for key in energies_label.energies_dict.keys():\n", + " for energy in callback.energy:\n", + " energies_dict[key].append(energy.energies_dict[key])\n", + " energies_dict[key] = np.array(energies_dict[key])\n", + "\n", + "# build the total energy from the different contributions\n", + "total_energy = np.zeros_like(next(iter(energies_dict.values())))\n", + "for key in energies_dict.keys():\n", + " total_energy += energies_dict[key]\n", + "energies_dict[\"total_energy\"] = total_energy\n", + "\n", + "# plot a curve of how the energy (as predicted by our model) evolved during denop\n", + "# kin_plus_xc is the energy that our model directly predicts\n", + "# all other energy contributions are computed from the density (with existing functionals)\n", + "# the total energy is the sum of all contributions\n", + "# important note: the total energy is therefore an approximation to the true DFT energy\n", + "# based on our learned functional (model)\n", + "for key in energies_dict.keys():\n", + " plt.plot(energies_dict[key], label=key)\n", + "plt.xlabel(\"Denop iteration\")\n", + "plt.ylabel(\"Predicted energy (mHa)\")\n", + "plt.title(\"Energy during density optimization\")\n", + "plt.legend(loc=(1.01, 0))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "1324daad", + "metadata": {}, + "source": [ + "Lastly, a full overview to the denop process:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ddebce2", + "metadata": {}, + "outputs": [], + "source": [ + "# okay in the above plot we don't see much change in the energies during denop\n", + "# more interesting is to look at the difference between the predicted energies\n", + "# and the energy ground state labels\n", + "# for that, we can use the plot from plot_density_optimization\n", + "# that is also shown in the pdf of denop plots:\n", + "from mldft.ml.data.components.basis_transforms import transform_tensor_with_sample\n", + "from mldft.ml.data.components.of_data import Representation\n", + "from mldft.utils.plotting.density_optimization import plot_density_optimization\n", + "\n", + "# for comparison, transform the ground state coeffs back to the untransformed representation,\n", + "# since the trajectory coeffs are transformed back before in the callback\n", + "gs_coeffs = transform_tensor_with_sample(\n", + " sample_double, sample_double.ground_state_coeffs, Representation.VECTOR, invert=True\n", + ")\n", + "\n", + "fig = plot_density_optimization(\n", + " callback=callback,\n", + " energies_label=energies_label,\n", + " coeffs_label=gs_coeffs, # ground state density coefficients as label used for computing the density error\n", + " sample=sample_double,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "dd640b57", + "metadata": {}, + "source": [ + "In this plot, we first see the energy differences to the ground state energy in mHa. In more detail, \n", + "it oscialltes around the groundstate energy until it converges to a final energy that is slightly above the ground state energy.\n", + "\n", + "The second panel shows the error between the predicted and the target density as well as the gradient norm.\n", + "The density error is computed as the L2 norm of the difference between the predicted and target density on the grid.\n", + "The gradient norm is the norm of the gradient of the energy with respect to the density coefficients. \n", + "\n", + "The third panel shows the change in the density coefficients during denop.\n", + "\n", + "Finally, the last panel shows the dipole moment differences to the ground state dipole moment in au (atomic units).\n" + ] + }, + { + "cell_type": "markdown", + "id": "80ef612c", + "metadata": {}, + "source": [ + "In this tutorial, we have illustrated some of the inner workings behind density optimization. For many of the small steps like getting a data sample one which we can run density optimization or \n", + "there exists some high level functionality in our code base to do them (in the [run_density_optimization.py](../../mldft/ofdft/run_density_optimization.py)): \n", + "\n", + "`SampleGenerator`: a class to obain individual data samples from a full model/data config \n", + "`run_singlepoint_ofdft`: a function that runs a full density opitmization for the given molecule " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mldft", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/tutorial/tutorial_fig1.png b/notebooks/tutorial/tutorial_fig1.png new file mode 100644 index 0000000..53af7a4 Binary files /dev/null and b/notebooks/tutorial/tutorial_fig1.png differ