{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# **SchNetPack Ensemble Calculator for Atomistic Simulations**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tutorial demonstrates how to use SchNetPack’s ensemble calculator to predict atomic energies and forces with uncertainty estimation.\n", "\n", "We’ll walk through the following examples:\n", "- How to calculate ensemble-based uncertainty\n", "- Structure relaxation using ensemble predictions and uncertainty\n", "- Running molecular dynamics (MD) simulations with uncertainty\n", "\n", "These tools are useful for identifying uncertain regions in simulation trajectories and making more informed decisions in atomistic modeling.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2025-05-09T22:31:38.016929Z", "start_time": "2025-05-09T22:31:30.976562Z" } }, "outputs": [], "source": [ "import numpy as np\n", "from tqdm import tqdm\n", "import matplotlib.pyplot as plt\n", "\n", "from ase import units\n", "from ase.io import read\n", "from ase.optimize.lbfgs import LBFGS\n", "from ase.md.velocitydistribution import MaxwellBoltzmannDistribution\n", "from ase.md.langevin import Langevin\n", "\n", "from schnetpack.interfaces.ase_interface import SpkEnsembleCalculator, AbsoluteUncertainty, RelativeUncertainty\n", "import schnetpack.transform as trn\n", "from schnetpack.datasets import MD17\n", "import torch\n", "\n", "\n", "np.random.seed(42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Ensemble Interface to ASE" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We specify a list of PaiNN models trained on ethanol structures from the rMD17 dataset. These models constitute the **ensemble**, which will serve as a testbed for SchNetPack’s ensemble-enabled `SpkEnsembleCalculator`.\n", "\n", "**Note**: The models have been trained on 1000 samples only." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2025-05-09T22:31:38.044782Z", "start_time": "2025-05-09T22:31:38.041032Z" } }, "outputs": [], "source": [ "model_path_list = ['../trained_models/rmd17_ethanol/painn_1/best_model',\n", " '../trained_models/rmd17_ethanol/painn_2/best_model',\n", " '../trained_models/rmd17_ethanol/painn_3/best_model',\n", " '../trained_models/rmd17_ethanol/painn_4/best_model',\n", " '../trained_models/rmd17_ethanol/painn_5/best_model']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ⚖️ Creating an Ensemble Calculator with Uncertainty Quantification\n", "\n", "In this section, we instantiate two different uncertainty estimators:\n", "- `AbsoluteUncertainty` calculates raw standard deviation values.\n", "- `RelativeUncertainty` gives uncertainty as a fraction of the mean, which helps when comparing predictions on different scales.\n", "\n", "Both uncertainty methods are bundled together in `SpkEnsembleCalculator`. This lets us evaluate uncertainty in multiple ways for the same prediction run, giving a more complete picture of model confidence.\n", "\n", "Finally, we create the `SpkEnsembleCalculator`, which uses multiple trained models to make predictions. It also estimates uncertainty using the methods we provided. This calculator will act just like a regular ASE calculator but with built-in support for ensemble averaging and uncertainty tracking.\n", "\n", "Note that you can also define custom uncertainty methods and pass them to the `SpkEnsembleCalculator`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2025-05-09T22:33:04.364669Z", "start_time": "2025-05-09T22:33:04.325679Z" } }, "outputs": [], "source": [ "uncertainty_abs = AbsoluteUncertainty(energy_weight=0.5,force_weight=1.0)\n", "uncertainty_rel = RelativeUncertainty(energy_weight=1.0, force_weight=2.0)\n", "\n", "uncertainty = [uncertainty_abs, uncertainty_rel]\n", "\n", "ensemble_calculator = SpkEnsembleCalculator(\n", " models=model_path_list,\n", " neighbor_list=trn.ASENeighborList(cutoff=5.0),\n", " energy_key=MD17.energy,\n", " force_key=MD17.forces,\n", " energy_unit=\"kcal/mol\",\n", " position_unit=\"Ang\",\n", " uncertainty_fn=uncertainty)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Assign the ensemble calculator `ensemble_calculator` to the atoms object" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2025-05-09T22:33:10.891141Z", "start_time": "2025-05-09T22:33:10.873903Z" } }, "outputs": [], "source": [ "#load data into atoms object\n", "atoms = read('../../tests/testdata/md_ethanol.xyz', index=0)\n", "# specify atoms calculator\n", "atoms.calc = ensemble_calculator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " 🔮 Prediction Output:\n", "- ⚡ Energy: Total potential energy of the atomic system.\n", "- 🔧 Forces: Atomic forces for optimization or molecular dynamics.\n", "- 📊 Uncertainty: Estimation of model prediction uncertainty from the ensemble." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2025-05-09T22:33:17.551945Z", "start_time": "2025-05-09T22:33:17.390831Z" } }, "outputs": [], "source": [ "print(\"Prediction:\")\n", "print(\"energy:\", atoms.get_potential_energy())\n", "print(\"forces:\", atoms.get_forces())\n", "print(\"uncertainty:\", ensemble_calculator.get_uncertainty(atoms))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Structure Optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "🏗️ **Distort molecular structure:**\n", "- A small random displacement (Gaussian noise with a standard deviation of 0.1 Å) is added to the atomic positions.\n", "- This simulates a noisy or slightly perturbed structure, which helps us test how sensitive the ensemble model is to input changes.\n", "- It also makes the uncertainty values more meaningful by introducing some variability." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2025-05-09T22:33:21.723617Z", "start_time": "2025-05-09T22:33:21.721234Z" } }, "outputs": [], "source": [ "# distort the structure with Gaussian noise (standard deviation 0.1 Å)\n", "atoms.positions += np.random.normal(0, 0.1, atoms.positions.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "⚙️ **Set Up Optimization:**\n", "- 🧑‍🔬 **Optimizer**: We set up the LBFGS optimizer, a gradient-based method that minimizes the energy of the atomic system.\n", "- 🔬 **Calculator**: Assign the ensemble calculator to the atoms object. This connects the prediction engine (our ensemble of models) to the atomic system."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2025-05-09T22:33:22.818549Z", "start_time": "2025-05-09T22:33:22.815633Z" } }, "outputs": [], "source": [ "optimizer = LBFGS(atoms)\n", "atoms.calc = ensemble_calculator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "🔄 **Optimization Loop with Uncertainty Tracking:**\n", "- 🧑‍🔬 **Optimizer**: Run the optimization using the LBFGS algorithm with a force tolerance of 0.05 eV/Å (`fmax=0.05`) and a maximum of 300 steps.\n", "- 📊 **Uncertainty Tracking**: After each optimization step, the ensemble uncertainty returned by `get_uncertainty` is appended to the `uncertainties` list, providing insight into model confidence during the process." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2025-05-09T22:33:27.909342Z", "start_time": "2025-05-09T22:33:26.805029Z" } }, "outputs": [], "source": [ "uncertainties = []\n", "\n", "# irun yields after every optimizer step, so we can record the uncertainty along the way\n", "for _ in optimizer.irun(fmax=0.05, steps=300):\n", "    uncertainties.append(ensemble_calculator.get_uncertainty(atoms))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we're using an ensemble of models, we can now estimate the uncertainty in our predictions during the optimization process:\n", "- The **absolute** and **relative** uncertainty values are extracted for each optimization step.\n", "- Both uncertainties are plotted against the optimization step to visualize how they change during the process."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2025-05-09T22:33:46.653285Z", "start_time": "2025-05-09T22:33:46.434122Z" } }, "outputs": [], "source": [ "# Extract individual uncertainty types\n", "abs_vals = [d[\"AbsoluteUncertainty\"] for d in uncertainties]\n", "rel_vals = [d[\"RelativeUncertainty\"] for d in uncertainties]\n", "steps = list(range(len(uncertainties)))\n", "\n", "# Create figure and first axis\n", "fig, ax1 = plt.subplots(figsize=(8, 6))\n", "\n", "# Plot absolute uncertainty on left y-axis\n", "ax1.plot(steps, abs_vals, label=\"Absolute Uncertainty\", marker='o', color='tab:blue')\n", "ax1.set_xlabel(\"Optimization Step\")\n", "ax1.set_ylabel(\"Absolute Uncertainty\", color='tab:blue')\n", "ax1.tick_params(axis='y', labelcolor='tab:blue')\n", "ax1.grid(True)\n", "\n", "# Create second y-axis for relative uncertainty\n", "ax2 = ax1.twinx()\n", "ax2.plot(steps, rel_vals, label=\"Relative Uncertainty\", marker='x', color='tab:red')\n", "ax2.set_ylabel(\"Relative Uncertainty\", color='tab:red')\n", "ax2.tick_params(axis='y', labelcolor='tab:red')\n", "\n", "# Title and layout\n", "plt.title(\"Uncertainty during Optimization\")\n", "fig.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "While the absolute uncertainty decreases rapidly and then stays low, the relative uncertainty increases as the structure optimization converges. This rise in relative uncertainty is due to the diminishing force magnitudes: although the prediction uncertainty stays nearly constant, the mean predicted values become very small, leading to a larger ratio between uncertainty and prediction." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Molecular Dynamics With Increasing Temperature" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now investigate the behavior of the uncertainty measure during an MD simulation. 
To this end, we perform a simulation in the canonical ensemble (NVT), gradually increasing the temperature of the heat bath throughout the run. As the temperature rises, we expect larger deviations of the molecular structure from equilibrium configurations. Consequently, the system is more likely to sample structures that lie outside the training distribution of the machine learning force field. This effect is reflected in the absolute uncertainty measure, which increases with temperature.\n", "\n", "In this setup, we use only the absolute uncertainty to measure how much the model predictions vary across the ensemble.\n", "\n", "🔎 **Note**: The `uncertainty_fn` can be passed as either a **single** uncertainty function or as a **list** of uncertainty functions.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2025-05-09T22:34:01.039825Z", "start_time": "2025-05-09T22:34:00.840244Z" } }, "outputs": [], "source": [ "uncertainty_abs = AbsoluteUncertainty(energy_weight=0.5, force_weight=1.0)\n", "\n", "abs_ensemble_calculator = SpkEnsembleCalculator(\n", "    models=model_path_list,\n", "    neighbor_list=trn.ASENeighborList(cutoff=5.0),\n", "    energy_key=MD17.energy,\n", "    force_key=MD17.forces,\n", "    energy_unit=\"kcal/mol\",\n", "    position_unit=\"Ang\",\n", "    uncertainty_fn=uncertainty_abs,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2025-05-09T22:36:41.958945Z", "start_time": "2025-05-09T22:34:05.121431Z" } }, "outputs": [], "source": [ "target_temperatures = list(range(50, 800, 100))  # heat-bath temperatures in K\n", "n_steps = 1000  # MD steps per temperature\n", "sampling_interval = 10  # record data every 10 MD steps\n", "step_size = 0.5  # integration time step in fs\n", "\n", "# setting up initial atoms\n", "atoms = read('../../tests/testdata/md_ethanol.xyz', index=0)\n", "atoms.calc = abs_ensemble_calculator\n", "\n", "MaxwellBoltzmannDistribution(atoms, temperature_K=target_temperatures[0])\n", "\n", "ats_traj = []\n", "uncertainties = []\n", "temp = []\n", "\n", "for target_temperature in target_temperatures:\n", "    print(f\"Temp: {target_temperature:.2f} K\")\n", "\n", "    # one Langevin integrator per heat-bath temperature\n", "    dyn = Langevin(\n", "        atoms,\n", "        timestep=step_size * units.fs,\n", "        temperature_K=target_temperature,\n", "        friction=0.01 / units.fs,\n", "    )\n", "\n", "    for _ in tqdm(range(n_steps // sampling_interval)):\n", "        dyn.run(sampling_interval)\n", "\n", "        # record temperature, uncertainty and structure after each sampling interval\n", "        temp.append(atoms.get_temperature())\n", "        uncertainties.append(abs_ensemble_calculator.get_uncertainty(atoms))\n", "        ats_traj.append(atoms.copy())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax1 = plt.subplots(figsize=(8, 6))\n", "\n", "# each data point corresponds to one sampling interval, not to a single MD step\n", "ax1.plot(uncertainties, marker='o', color='blue', label='Uncertainty')\n", "ax1.set_xlabel(f\"Sample index (one per {sampling_interval} MD steps)\")\n", "ax1.set_ylabel(\"Uncertainty\", color='blue')\n", "ax1.tick_params(axis='y', labelcolor='blue')\n", "\n", "ax2 = ax1.twinx()\n", "ax2.plot(temp, marker='x', color='red', label='Temperature')\n", "ax2.set_ylabel(\"Temperature (K)\", color='red')\n", "ax2.tick_params(axis='y', labelcolor='red')\n", "\n", "plt.title(\"Molecular Dynamics: Uncertainty and Temperature Profile\")\n", "ax1.grid(True)\n", "\n", "lines_1, labels_1 = ax1.get_legend_handles_labels()\n", "lines_2, labels_2 = ax2.get_legend_handles_labels()\n", "ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc='upper right')\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualize the MD trajectory of the structure to make sure that nothing went wrong." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ase.visualize import view\n", "view(ats_traj)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 4 }