{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Training a neural network on QM9\n", "\n", "This tutorial explains how to use SchNetPack to train a model\n", "on the QM9 dataset and how the trained model can be used for further applications.\n", "\n", "First, we import the necessary modules and create a new directory for the data and our model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import os\n", "import schnetpack as spk\n", "from schnetpack.datasets import QM9\n", "import schnetpack.transform as trn\n", "\n", "import torch\n", "import torchmetrics\n", "import pytorch_lightning as pl\n", "\n", "qm9tut = \"./qm9tut\"\n", "if not os.path.exists(qm9tut):\n", "    os.makedirs(qm9tut)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Loading the data\n", "\n", "As explained in the [previous tutorial](tutorial_01_preparing_data.ipynb), datasets in SchNetPack are loaded with the `AtomsDataModule` class or one of its subclasses that are specialized for common benchmark datasets.\n", "The `QM9` dataset class will download and convert the data. 
We will only use the internal energy at 0 K, `U0`, so the other properties do not need to be loaded:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# remove the split file of a previous run, if any, so that a fresh split is created\n", "%rm -f ./qm9tut/split.npz\n", "\n", "qm9data = QM9(\n", "    \"./qm9.db\",\n", "    batch_size=100,\n", "    num_train=1000,\n", "    num_val=1000,\n", "    transforms=[\n", "        trn.ASENeighborList(cutoff=5.0),\n", "        trn.RemoveOffsets(QM9.U0, remove_mean=True, remove_atomrefs=True),\n", "        trn.CastTo32(),\n", "    ],\n", "    property_units={QM9.U0: \"eV\"},\n", "    num_workers=1,\n", "    split_file=os.path.join(qm9tut, \"split.npz\"),\n", "    pin_memory=True,  # set to False when not using a GPU\n", "    load_properties=[QM9.U0],  # only load the U0 property\n", ")\n", "qm9data.prepare_data()\n", "qm9data.setup()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "The dataset is downloaded and partitioned automatically. PyTorch `DataLoader`s can be obtained using `qm9data.train_dataloader()`, `qm9data.val_dataloader()` and `qm9data.test_dataloader()`.\n", "\n", "Before building the model, we remove offsets from the energy to obtain good initial conditions. These offsets are estimated from the training dataset. 
Above, this is done automatically by the `RemoveOffsets` transform.\n", "In the following, we show what happens under the hood.\n", "For QM9, we also have single-atom reference values stored in the metadata:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "atomrefs = qm9data.train_dataset.atomrefs\n", "print(\"U0 of hydrogen:\", atomrefs[QM9.U0][1].item(), \"eV\")\n", "print(\"U0 of carbon:\", atomrefs[QM9.U0][6].item(), \"eV\")\n", "print(\"U0 of oxygen:\", atomrefs[QM9.U0][8].item(), \"eV\")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "These can be used together with the mean and standard deviation of the energy per atom to initialize the model with a good guess of the energy of a molecule. When calculating these statistics, we pass the atomrefs to account for the fact that the model will add them to the predicted energy later, so that this part of the energy does not have to be included in the statistics, i.e.\n", "\\begin{equation}\n", "\\mu_{U_0} = \\frac{1}{n_\\text{train}} \\sum_{n=1}^{n_\\text{train}} \\frac{1}{n_{\\text{atoms},n}} \\left( U_{0,n} - \\sum_{i=1}^{n_{\\text{atoms},n}} U_{0,Z_{n,i}} \\right)\n", "\\end{equation}\n", "for the mean and analogously for the standard deviation. In this case, this corresponds to the mean and standard deviation of the *atomization energy* per atom." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "means, stddevs = qm9data.get_stats(QM9.U0, divide_by_atoms=True, remove_atomref=True)\n", "print(\"Mean atomization energy / atom:\", means.item())\n", "print(\"Std. dev. 
atomization energy / atom:\", stddevs.item())" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Setting up the model\n", "\n", "Next, we need to build the model and define how it should be trained.\n", "\n", "In SchNetPack, a neural network potential usually consists of three parts:\n", "\n", "1. A list of input modules that prepare the batched data before building the representation.\n", "   This includes, e.g., calculating pairwise distances between atoms based on neighbor indices or adding auxiliary\n", "   inputs for response properties.\n", "2. The representation, which constructs atom-wise features, e.g. with SchNet or PaiNN.\n", "3. One or more output modules for property prediction.\n", "\n", "Here, we use the `SchNet` representation with 3 interaction layers and a 5 Angstrom cosine cutoff, with pairwise distances\n", "expanded on 20 Gaussians. Since we only have a few training examples, we use just 30 atom-wise features and convolution filters.\n", "Then, we use an `Atomwise` module to predict the internal energy $U_0$ by summing over atom-wise\n", "energy contributions." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "cutoff = 5.0\n", "n_atom_basis = 30\n", "\n", "# calculates pairwise distances between atoms\n", "pairwise_distance = spk.atomistic.PairwiseDistances()\n", "radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=cutoff)\n", "schnet = spk.representation.SchNet(\n", "    n_atom_basis=n_atom_basis,\n", "    n_interactions=3,\n", "    radial_basis=radial_basis,\n", "    cutoff_fn=spk.nn.CosineCutoff(cutoff),\n", ")\n", "pred_U0 = spk.atomistic.Atomwise(n_in=n_atom_basis, output_key=QM9.U0)\n", "\n", "nnpot = spk.model.NeuralNetworkPotential(\n", "    representation=schnet,\n", "    input_modules=[pairwise_distance],\n", "    output_modules=[pred_U0],\n", "    postprocessors=[\n", "        trn.CastTo64(),\n", "        trn.AddOffsets(QM9.U0, add_mean=True, add_atomrefs=True),\n", "    ],\n", ")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "The last argument here is a list of postprocessors that will only be used if `nnpot.inference_mode=True` is set.\n", "They are not applied during training or validation, only for predictions.\n", "Here, they handle numerical precision and the normalization of the model outputs:\n", "To make training easier, we have subtracted the single-atom energies as well as the mean energy per atom\n", "in the preprocessing (see above).\n", "This does not matter for the loss, but for the final prediction we want to get the real energies.\n", "Additionally, we have removed the energy offsets *before* casting to float32 in the preprocessor.\n", "This avoids a loss of numerical precision.\n", "Analogously, we first have to cast to float64 before re-adding the offsets in the postprocessor.\n", "\n", "The output modules store the prediction in a dictionary under the `output_key` (here: `QM9.U0`), which is connected to\n", "a target property with loss functions and evaluation metrics using the 
`ModelOutput` class:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "output_U0 = spk.task.ModelOutput(\n", "    name=QM9.U0,\n", "    loss_fn=torch.nn.MSELoss(),\n", "    loss_weight=1.0,\n", "    metrics={\"MAE\": torchmetrics.MeanAbsoluteError()},\n", ")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "By default, the target is assumed to have the same name as the output. Otherwise, a different `target_name`\n", "has to be provided.\n", "Here, we already gave the output the same name as the target in the dataset (`QM9.U0`).\n", "In case of multiple outputs, the full loss is a weighted sum of all output losses.\n", "Therefore, it is possible to provide a `loss_weight`, which we simply set to 1 here.\n", "\n", "All components defined above are then passed to `AtomisticTask`, which is a subclass of\n", "[`LightningModule`](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html).\n", "This connects the model and the training process and can then be passed to the PyTorch Lightning `Trainer`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "task = spk.task.AtomisticTask(\n", "    model=nnpot,\n", "    outputs=[output_U0],\n", "    optimizer_cls=torch.optim.AdamW,\n", "    optimizer_args={\"lr\": 1e-4},\n", ")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Training the model\n", "\n", "Now, the model is ready for training. Since we already defined all necessary components, the only thing left to do is\n", "to pass the task to the PyTorch Lightning `Trainer` together with the data module.\n", "\n", "Additionally, we can provide callbacks that take care of logging, checkpointing, etc." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "logger = pl.loggers.TensorBoardLogger(save_dir=qm9tut)\n", "callbacks = [\n", "    spk.train.ModelCheckpoint(\n", "        model_path=os.path.join(qm9tut, \"best_inference_model\"),\n", "        save_top_k=1,\n", "        monitor=\"val_loss\",\n", "    )\n", "]\n", "\n", "trainer = pl.Trainer(\n", "    callbacks=callbacks,\n", "    logger=logger,\n", "    default_root_dir=qm9tut,\n", "    max_epochs=3,  # for testing, we restrict the number of epochs\n", ")\n", "trainer.fit(task, datamodule=qm9data)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "The `ModelCheckpoint` of SchNetPack is equivalent to that in PyTorch Lightning,\n", "except that we also store the best inference model. We will show how to use this in the next section.\n", "\n", "You can have a look at the training log using TensorBoard:\n", "```\n", "tensorboard --logdir=qm9tut/lightning_logs\n", "```\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Inference\n", "\n", "Having trained a model for QM9, we are going to use it to obtain some predictions.\n", "First, we need to load the model. 
The checkpoint callback stores the best inference model in the model directory, from where it can be loaded with the `load_model` helper:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "from ase import Atoms\n", "from schnetpack.utils.compatibility import load_model\n", "\n", "best_model = load_model(\n", "    os.path.join(qm9tut, \"best_inference_model\"), device=\"cpu\"\n", ")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "We can use the test dataloader from the QM9 data module to obtain a batch of molecules and apply the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "for batch in qm9data.test_dataloader():\n", "    result = best_model(batch)\n", "    print(\"Result dictionary:\", result)\n", "    break" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "If your data is not already in SchNetPack format, a convenient way to prepare model inputs is to use ASE atoms with the\n", "provided `AtomsConverter`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "converter = spk.interfaces.AtomsConverter(\n", "    neighbor_list=trn.ASENeighborList(cutoff=5.0), dtype=torch.float32\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "numbers = np.array([6, 1, 1, 1, 1])\n", "positions = np.array(\n", "    [\n", "        [-0.0126981359, 1.0858041578, 0.0080009958],\n", "        [0.002150416, -0.0060313176, 0.0019761204],\n", "        [1.0117308433, 1.4637511618, 0.0002765748],\n", "        [-0.540815069, 1.4475266138, -0.8766437152],\n", "        [-0.5238136345, 1.4379326443, 0.9063972942],\n", "    ]\n", ")\n", "atoms = Atoms(numbers=numbers, positions=positions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": 
{ "name": "#%%\n" } }, "outputs": [], "source": [ "inputs = converter(atoms)\n", "\n", "print(\"Keys:\", list(inputs.keys()))\n", "\n", "pred = best_model(inputs)\n", "\n", "print(\"Prediction:\", pred[QM9.U0])" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Alternatively, one can use the `SpkCalculator` as an interface to ASE. The calculator requires the path to a trained model and a neighbor list as input. In addition, the names and units of the properties used in the model (e.g. the energy) should be provided. Precision and device can be set via the `dtype` and `device` keywords:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "calculator = spk.interfaces.SpkCalculator(\n", "    model=os.path.join(qm9tut, \"best_inference_model\"),  # path to model\n", "    neighbor_list=trn.ASENeighborList(cutoff=5.0),  # neighbor list\n", "    energy_key=QM9.U0,  # name of energy property in model\n", "    force_key=None,\n", "    energy_unit=\"eV\",  # units of energy property\n", "    device=\"cpu\",  # device for computation\n", ")\n", "atoms.calc = calculator\n", "print(\"Prediction:\", atoms.get_total_energy())" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "The calculator automatically converts the prediction from the given unit to internal ASE units, which are `eV`\n", "for the energy.\n", "Using the calculator interface makes more sense if you have trained SchNet for a potential energy surface.\n", "In the next tutorials, we will show how to learn potential energy surfaces and force fields as well as how to perform\n", "molecular dynamics simulations with SchNetPack." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" }, "nbsphinx": { "execute": "never" } }, "nbformat": 4, "nbformat_minor": 4 }