SchNetPack Ensemble Calculator for Atomistic Simulations
This tutorial demonstrates how to use SchNetPack’s ensemble calculator to predict atomic energies and forces with uncertainty estimation.
We’ll walk through the following examples:
How to calculate ensemble-based uncertainty
Structure relaxation using ensemble predictions and uncertainty
Running molecular dynamics (MD) simulations with uncertainty
These tools are useful for identifying uncertain regions in simulation trajectories and making more informed decisions in atomistic modeling.
[1]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from ase import units
from ase.io import read
from ase.optimize.lbfgs import LBFGS
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.md.langevin import Langevin
from schnetpack.interfaces.ase_interface import SpkEnsembleCalculator, AbsoluteUncertainty, RelativeUncertainty
import schnetpack.transform as trn
from schnetpack.datasets import MD17
import torch
np.random.seed(42)
Ensemble Interface to ASE
We specify a list of PaiNN models trained on ethanol structures from the rMD17 dataset. These models constitute the ensemble, which will serve as a testbed for SchNetPack’s ensemble-enabled SpkEnsembleCalculator.
Note: Each model was trained on only 1000 samples.
[2]:
model_path_list = ['../trained_models/rmd17_ethanol/painn_1/best_model',
                   '../trained_models/rmd17_ethanol/painn_2/best_model',
                   '../trained_models/rmd17_ethanol/painn_3/best_model',
                   '../trained_models/rmd17_ethanol/painn_4/best_model',
                   '../trained_models/rmd17_ethanol/painn_5/best_model']
⚖️ Creating an Ensemble Calculator with Uncertainty Quantification
In this section, we instantiate two different uncertainty estimators:
AbsoluteUncertainty calculates raw standard deviation values.
RelativeUncertainty gives uncertainty as a fraction of the mean, which helps when comparing predictions on different scales.
Both estimators are passed to SpkEnsembleCalculator together as a list. This lets us evaluate uncertainty in multiple ways for the same prediction run, giving a more complete picture of model confidence.
Finally, we create the SpkEnsembleCalculator, which uses multiple trained models to make predictions. It also estimates uncertainty using the methods we provided. This calculator will act just like a regular ASE calculator but with built-in support for ensemble averaging and uncertainty tracking.
Note that you can also define custom uncertainty methods and pass them to the SpkEnsembleCalculator.
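SchNetPack's exact uncertainty formulas are not shown in this tutorial. As a rough mental model only, the sketch below assumes that the absolute measure is a weighted sum of the ensemble standard deviation of the energy and the mean per-component standard deviation of the forces, and that the relative measure divides each term by the magnitude of the corresponding ensemble mean. The function `ensemble_uncertainty`, these averaging choices, and `eps` are our assumptions, not the library's implementation; the toy arrays stand in for per-model outputs.

```python
import numpy as np

def ensemble_uncertainty(energies, forces, energy_weight=1.0, force_weight=1.0,
                         relative=False, eps=1e-12):
    """Hypothetical sketch, not SchNetPack's implementation.

    energies: shape (n_models,)
    forces:   shape (n_models, n_atoms, 3)
    """
    e_std = energies.std()                     # spread of the energy across models
    f_std = forces.std(axis=0).mean()          # average per-component force spread
    if relative:
        # divide each spread by the magnitude of the ensemble mean
        e_std = e_std / (abs(energies.mean()) + eps)
        f_std = f_std / (np.abs(forces.mean(axis=0)).mean() + eps)
    return float(energy_weight * e_std + force_weight * f_std)

rng = np.random.default_rng(0)
energies = -4210.0 + 0.01 * rng.standard_normal(5)                     # 5 toy models
forces = rng.standard_normal((9, 3)) + 0.05 * rng.standard_normal((5, 9, 3))
abs_u = ensemble_uncertainty(energies, forces, energy_weight=0.5, force_weight=1.0)
rel_u = ensemble_uncertainty(energies, forces, 1.0, 2.0, relative=True)
print(f"absolute: {abs_u:.4f}  relative: {rel_u:.4f}")
```

Under these assumptions, scaling all predictions by a constant scales the absolute value by that constant but leaves the relative value unchanged, which is why a relative measure helps when comparing quantities of very different magnitude.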
[3]:
uncertainty_abs = AbsoluteUncertainty(energy_weight=0.5, force_weight=1.0)
uncertainty_rel = RelativeUncertainty(energy_weight=1.0, force_weight=2.0)
uncertainty = [uncertainty_abs, uncertainty_rel]
ensemble_calculator = SpkEnsembleCalculator(
    models=model_path_list,
    neighbor_list=trn.ASENeighborList(cutoff=5.0),
    energy_key=MD17.energy,
    force_key=MD17.forces,
    energy_unit="kcal/mol",
    position_unit="Ang",
    uncertainty_fn=uncertainty,
)
Next, we load a structure into an ASE Atoms object and assign ensemble_calculator as its calculator.
[4]:
# load the first frame of the ethanol trajectory into an ASE Atoms object
atoms = read('../../tests/testdata/md_ethanol.xyz', index=0)
# attach the ensemble calculator
atoms.calc = ensemble_calculator
🔮 Prediction Output:
⚡ Energy: Total potential energy of the atomic system.
🔧 Forces: Atomic forces for optimization or molecular dynamics.
📊 Uncertainty: Estimation of model prediction uncertainty from the ensemble.
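We assume the reported energy and forces are the ensemble means, as is typical for ensemble calculators. The sketch below computes that mean prediction and, in addition, a per-atom force spread that shows which atoms the ensemble disagrees on most. The helper `ensemble_predict` and the toy numbers are illustrative, not SchNetPack API or real model output.

```python
import numpy as np

def ensemble_predict(energies, forces):
    """Mean energy, mean forces, and per-atom force spread over the ensemble.

    energies: (n_models,), forces: (n_models, n_atoms, 3)
    """
    mean_energy = energies.mean()
    mean_forces = forces.mean(axis=0)                            # (n_atoms, 3)
    # per-atom spread: norm of the per-component standard deviation vector
    force_spread = np.linalg.norm(forces.std(axis=0), axis=1)   # (n_atoms,)
    return mean_energy, mean_forces, force_spread

# toy per-model outputs: 5 models, 9 atoms; only atom 0 has a nonzero force
energies = np.array([-4210.13, -4210.12, -4210.14, -4210.13, -4210.13])
forces = np.zeros((5, 9, 3))
forces[:, 0, 0] = [0.70, 0.71, 0.69, 0.70, 0.70]

energy, mean_forces, spread = ensemble_predict(energies, forces)
print(energy, mean_forces[0], spread.argmax())
```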
[5]:
print("Prediction:")
print("energy:", atoms.get_total_energy())
print("forces:", atoms.get_forces())
print("uncertainty:", ensemble_calculator.get_uncertainty(atoms))
Prediction:
energy: -4210.131428631634
forces: [[ 0.70392037 0.08519621 0.08806095]
[-0.47793492 -0.45186286 -0.44601064]
[-0.24903458 0.53124927 0.51912287]
[-0.31574337 -0.7336392 0.20814634]
[-0.30961297 0.1954594 -0.73971642]
[ 0.30894925 -0.11190948 0.97377769]
[ 0.31608053 0.96932102 -0.12929553]
[-1.33278014 -0.46331379 -0.46274125]
[ 1.35615576 -0.02050054 -0.01134399]]
uncertainty: {'AbsoluteUncertainty': 0.003303904209174205, 'RelativeUncertainty': 0.003354954798337133}
Structure Optimization
🏗️ Distort molecular structure:
A small random disturbance is added to the atomic positions.
This simulates a noisy or slightly perturbed structure, which helps us test how sensitive the ensemble model is to input changes.
It also makes the uncertainty values more meaningful by introducing some variability.
[6]:
# distort the structure: add Gaussian noise (std 0.1 Å) to every coordinate
atoms.positions += np.random.normal(0, 0.1, atoms.positions.shape)
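The perturbation above adds independent Gaussian noise with a standard deviation of 0.1 Å to every Cartesian coordinate, so the root-mean-square displacement per atom should be about √3 × 0.1 ≈ 0.17 Å. The sketch below checks this with a plain RMSD computed without rotational alignment; the helper `rmsd` and the random stand-in positions are our own, not the ethanol geometry.

```python
import numpy as np

def rmsd(a, b):
    """Root-mean-square deviation between two (n_atoms, 3) position arrays, no alignment."""
    return float(np.sqrt(((a - b) ** 2).sum(axis=1).mean()))

rng = np.random.default_rng(42)
positions = rng.uniform(-1.5, 1.5, size=(9, 3))          # stand-in for 9 atoms
perturbed = positions + rng.normal(0, 0.1, positions.shape)
print(f"RMSD (9 atoms): {rmsd(positions, perturbed):.3f} Å")
```

With only 9 atoms the value scatters around 0.17 Å; with many atoms it converges to it.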
⚙️ Set Up Optimization:
🧑🔬 Optimizer: We set up the LBFGS optimizer, which is a gradient-based method used to minimize the energy of the atomic system.
🔬 Calculator: Assign the ensemble calculator to the atoms object; this connects the prediction engine (our ensemble of models) to the atomic system.
[7]:
optimizer = LBFGS(atoms)
atoms.calc = ensemble_calculator
🔄 Optimization Loop with Uncertainty Tracking:
🧑🔬 Optimizer: Run the LBFGS optimization until the largest atomic force falls below fmax = 0.05 eV/Å, for at most 300 steps.
📊 Uncertainty Tracking: After each optimization step, the ensemble uncertainty (a dictionary with one entry per uncertainty estimator, combining the weighted energy and force terms) is appended to the uncertainties list, providing insight into model confidence during the process.
[8]:
uncertainties = []
for _ in optimizer.irun(fmax=0.05, steps=300):
    uncertainties.append(ensemble_calculator.get_uncertainty(atoms))
Step Time Energy fmax
LBFGS: 0 15:06:54 -4207.764579 10.558269
LBFGS: 1 15:06:54 -4209.616615 4.761182
LBFGS: 2 15:06:54 -4209.959119 1.813286
LBFGS: 3 15:06:54 -4210.068999 0.807915
LBFGS: 4 15:06:54 -4210.120258 0.611916
LBFGS: 5 15:06:54 -4210.163846 0.663199
LBFGS: 6 15:06:54 -4210.186417 0.450881
LBFGS: 7 15:06:54 -4210.195588 0.299218
LBFGS: 8 15:06:54 -4210.202498 0.225450
LBFGS: 9 15:06:54 -4210.207033 0.185857
LBFGS: 10 15:06:54 -4210.209379 0.113195
LBFGS: 11 15:06:54 -4210.210390 0.094224
LBFGS: 12 15:06:54 -4210.210921 0.064197
LBFGS: 13 15:06:54 -4210.211318 0.058944
LBFGS: 14 15:06:54 -4210.211717 0.065913
LBFGS: 15 15:06:54 -4210.212182 0.087708
LBFGS: 16 15:06:54 -4210.212657 0.071320
LBFGS: 17 15:06:54 -4210.213108 0.073163
LBFGS: 18 15:06:54 -4210.213617 0.069442
LBFGS: 19 15:06:54 -4210.214310 0.094219
LBFGS: 20 15:06:54 -4210.215213 0.121120
LBFGS: 21 15:06:54 -4210.216125 0.094519
LBFGS: 22 15:06:55 -4210.216759 0.069712
LBFGS: 23 15:06:55 -4210.217136 0.048263
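To our understanding of ASE's optimizers, the `fmax` column in the log above is the largest per-atom force norm in eV/Å, and the run stops once it drops below the requested `fmax=0.05` (here after 23 steps). The sketch applies the same measure to the forces printed earlier for the unperturbed first frame; the helper name `fmax` is our own.

```python
import numpy as np

def fmax(forces):
    """Largest per-atom force norm, the quantity compared against the `fmax` tolerance."""
    return float(np.sqrt((forces ** 2).sum(axis=1)).max())

# forces (eV/Å) printed for the unperturbed ethanol frame earlier in this tutorial
forces = np.array([
    [ 0.70392037,  0.08519621,  0.08806095],
    [-0.47793492, -0.45186286, -0.44601064],
    [-0.24903458,  0.53124927,  0.51912287],
    [-0.31574337, -0.7336392 ,  0.20814634],
    [-0.30961297,  0.1954594 , -0.73971642],
    [ 0.30894925, -0.11190948,  0.97377769],
    [ 0.31608053,  0.96932102, -0.12929553],
    [-1.33278014, -0.46331379, -0.46274125],
    [ 1.35615576, -0.02050054, -0.01134399],
])
print(f"fmax = {fmax(forces):.3f} eV/Å, below 0.05: {fmax(forces) < 0.05}")
```

That frame is an MD snapshot rather than a minimum, so its largest force (roughly 1.5 eV/Å, on the eighth atom) is far above the tolerance.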
Since we’re using an ensemble of models, we can now estimate the uncertainty in our predictions during the optimization process:
Absolute and Relative Uncertainty values are extracted from the optimization steps.
Plot both absolute and relative uncertainties against the optimization steps to visualize how uncertainty changes during the process.
[9]:
# Extract individual uncertainty types
abs_vals = [d["AbsoluteUncertainty"] for d in uncertainties]
rel_vals = [d["RelativeUncertainty"] for d in uncertainties]
steps = list(range(len(uncertainties)))
# Create figure and first axis
fig, ax1 = plt.subplots(figsize=(8, 6))
# Plot absolute uncertainty on left y-axis
ax1.plot(steps, abs_vals, label="Absolute Uncertainty", marker='o', color='tab:blue')
ax1.set_xlabel("Optimization Step")
ax1.set_ylabel("Absolute Uncertainty", color='tab:blue')
ax1.tick_params(axis='y', labelcolor='tab:blue')
ax1.grid(True)
# Create second y-axis for relative uncertainty
ax2 = ax1.twinx()
ax2.plot(steps, rel_vals, label="Relative Uncertainty", marker='x', color='tab:red')
ax2.set_ylabel("Relative Uncertainty", color='tab:red')
ax2.tick_params(axis='y', labelcolor='tab:red')
# Title and layout
plt.title("Uncertainty during Optimization")
fig.tight_layout()
plt.show()
While the absolute uncertainty rapidly decreases and remains consistently low and stable thereafter, the relative uncertainty increases as the structure optimization converges. This rise in relative uncertainty is due to the diminishing force magnitudes: although the prediction uncertainty stays nearly constant, the mean predicted values become very small, leading to a larger ratio between uncertainty and prediction.
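A small numeric illustration of this argument, under the simplifying assumption that the relative measure behaves like σ/|F̄| (ensemble spread divided by mean force magnitude): keep the spread fixed and let the mean force shrink as the optimizer approaches the minimum. The numbers are illustrative, not taken from the run above.

```python
import numpy as np

sigma = 0.01                                          # fixed ensemble spread (eV/Å), illustrative
mean_force = np.array([1.0, 0.3, 0.1, 0.03, 0.01])    # shrinking mean force magnitude (eV/Å)
relative = sigma / mean_force                         # ratio grows as the denominator shrinks

for f, r in zip(mean_force, relative):
    print(f"|F| = {f:5.2f}   absolute = {sigma:.2f}   relative = {r:.2f}")
```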
Molecular Dynamics With Increasing Temperature
We now investigate the behavior of the uncertainty measure during an MD simulation. To this end, we perform a simulation in the canonical (NVT) ensemble, raising the temperature of the heat bath stepwise throughout the run. As the temperature rises, we expect larger deviations of the molecular structure from equilibrium configurations. Consequently, the system is more likely to sample structures that lie outside the training distribution of the machine learning force field. This effect is reflected in the absolute uncertainty measure, which increases with temperature.
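The temperature schedule used in the MD cell below is easy to check by hand: eight target temperatures from 50 K to 750 K in 100 K increments, 1000 MD steps of 0.5 fs at each, and one sample every 10 steps. The arithmetic, reusing the notebook's variable names:

```python
target_temperatures = list(range(50, 800, 100))   # 50, 150, ..., 750 K
n_steps = 1000            # MD steps per target temperature
sampling_interval = 10    # record one sample every 10 steps
step_size_fs = 0.5        # time step in fs

n_stages = len(target_temperatures)
samples_per_stage = n_steps // sampling_interval
total_samples = n_stages * samples_per_stage
total_time_ps = n_stages * n_steps * step_size_fs / 1000
print(n_stages, samples_per_stage, total_samples, total_time_ps)  # 8 100 800 4.0
```

So each heat-bath temperature is held for 0.5 ps, the whole run spans 4 ps, and the uncertainty plot contains 800 points.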
In this setup, we use only absolute uncertainty to measure how much the model predictions vary across the ensemble:
🔎 Note: The uncertainty_fn can be passed as either a single uncertainty function or as a list of uncertainty functions.
[10]:
uncertainty_abs = AbsoluteUncertainty(energy_weight=0.5, force_weight=1.0)
abs_ensemble_calculator = SpkEnsembleCalculator(
    models=model_path_list,
    neighbor_list=trn.ASENeighborList(cutoff=5.0),
    energy_key=MD17.energy,
    force_key=MD17.forces,
    energy_unit="kcal/mol",
    position_unit="Ang",
    uncertainty_fn=uncertainty_abs,
)
[11]:
target_temperatures = list(range(50, 800, 100))  # 50, 150, ..., 750 K
n_steps = 1000           # MD steps per target temperature
sampling_interval = 10   # record one sample every 10 steps
step_size = 0.5          # time step in fs
# setting up initial atoms
atoms = read('../../tests/testdata/md_ethanol.xyz', index=0)
atoms.calc = abs_ensemble_calculator
MaxwellBoltzmannDistribution(atoms, temperature_K=target_temperatures[0])
ats_traj = []
uncertainties = []
temp = []
for target_temperature in target_temperatures:
    print(f"Temp: {target_temperature:.2f} K")
    # one Langevin (NVT) integrator per heat-bath temperature
    dyn = Langevin(
        atoms,
        timestep=step_size * units.fs,
        temperature_K=target_temperature,
        friction=0.01 / units.fs,
    )
    for _ in tqdm(range(n_steps // sampling_interval)):
        dyn.run(sampling_interval)
        # record temperature, uncertainty and a snapshot every sampling_interval steps
        temp.append(atoms.get_temperature())
        uncertainties.append(abs_ensemble_calculator.get_uncertainty(atoms))
        ats_traj.append(atoms.copy())
Temp: 50.00 K
100%|██████████| 100/100 [00:20<00:00, 4.90it/s]
Temp: 150.00 K
100%|██████████| 100/100 [00:20<00:00, 4.91it/s]
Temp: 250.00 K
100%|██████████| 100/100 [00:20<00:00, 4.88it/s]
Temp: 350.00 K
100%|██████████| 100/100 [00:20<00:00, 4.89it/s]
Temp: 450.00 K
100%|██████████| 100/100 [00:20<00:00, 4.90it/s]
Temp: 550.00 K
100%|██████████| 100/100 [00:20<00:00, 4.96it/s]
Temp: 650.00 K
100%|██████████| 100/100 [00:20<00:00, 4.99it/s]
Temp: 750.00 K
100%|██████████| 100/100 [00:20<00:00, 4.97it/s]
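The values stored in `temp` are instantaneous kinetic temperatures, which fluctuate around the heat-bath target, strongly so for a 9-atom molecule. To our understanding, ASE computes them from equipartition as T = 2 E_kin / (N_dof k_B) with N_dof = 3N (minus any degrees of freedom removed by constraints). The ASE-independent sketch below works in SI units, draws Maxwell–Boltzmann velocities (the idea behind `MaxwellBoltzmannDistribution`), and recovers the temperature; the helper names and rounded atomic masses are our own.

```python
import numpy as np

K_B = 1.380649e-23        # Boltzmann constant, J/K
AMU = 1.66053906660e-27   # atomic mass unit, kg

def maxwell_boltzmann_velocities(masses_kg, temperature, rng):
    """Draw velocities (m/s): each Cartesian component ~ N(0, sqrt(k_B T / m))."""
    sigma = np.sqrt(K_B * temperature / masses_kg)[:, None]
    return rng.standard_normal((len(masses_kg), 3)) * sigma

def kinetic_temperature(masses_kg, velocities):
    """T = 2 E_kin / (3 N k_B), assuming no constraints."""
    e_kin = 0.5 * (masses_kg[:, None] * velocities ** 2).sum()
    return 2.0 * e_kin / (3 * len(masses_kg) * K_B)

rng = np.random.default_rng(0)
ethanol = np.array([12.011] * 2 + [1.008] * 6 + [15.999]) * AMU   # C2H6O
many = np.full(100_000, 12.011) * AMU
for label, masses in [("9 atoms", ethanol), ("100k atoms", many)]:
    v = maxwell_boltzmann_velocities(masses, 300.0, rng)
    print(label, round(kinetic_temperature(masses, v), 1), "K")
```

With 9 atoms the estimate can deviate from 300 K by tens of kelvin; with 100 000 atoms it lands within a fraction of a percent.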
[12]:
fig, ax1 = plt.subplots(figsize=(8, 6))
ax1.plot(uncertainties, marker='o', color='blue', label='Uncertainty')
ax1.set_xlabel("MD Step")
ax1.set_ylabel("Uncertainty", color='blue')
ax1.tick_params(axis='y', labelcolor='blue')
ax2 = ax1.twinx()
ax2.plot(temp, marker='x', color='red', label='Temperature')
ax2.set_ylabel("Temperature (K)", color='red')
ax2.tick_params(axis='y', labelcolor='red')
plt.title("Molecular Dynamics: Uncertainty and Temperature Profile")
ax1.grid(True)
lines_1, labels_1 = ax1.get_legend_handles_labels()
lines_2, labels_2 = ax2.get_legend_handles_labels()
ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc='upper right')
plt.tight_layout()
plt.show()
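To go beyond reading the plot, one can flag the first sample whose uncertainty exceeds a threshold and report the heat-bath temperature at which it occurs. The threshold, the helper `first_exceedance`, and the synthetic uncertainties below are hypothetical and for illustration only; in the notebook, the recorded `uncertainties` (converted to plain numbers) would take the place of the toy array.

```python
import numpy as np

def first_exceedance(values, threshold):
    """Index of the first value above `threshold`, or None if none exceeds it."""
    above = np.flatnonzero(np.asarray(values) > threshold)
    return int(above[0]) if above.size else None

target_temperatures = list(range(50, 800, 100))
samples_per_stage = 100                      # n_steps // sampling_interval
# synthetic uncertainties that grow with temperature (illustration only)
rng = np.random.default_rng(0)
stage_level = np.repeat(np.linspace(0.002, 0.02, len(target_temperatures)), samples_per_stage)
toy_uncertainties = stage_level + 0.001 * rng.standard_normal(stage_level.size)

threshold = 0.01                             # hypothetical cutoff
idx = first_exceedance(toy_uncertainties, threshold)
if idx is not None:
    stage = idx // samples_per_stage
    print(f"first sample above {threshold}: #{idx}, target T = {target_temperatures[stage]} K")
```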
Let’s visualize the MD trajectory to check that the molecule stayed intact, for example that no bonds broke at high temperature.
[13]:
from ase.visualize import view
view(ats_traj)
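As a non-graphical complement to the visual check, one can monitor interatomic distances along the trajectory and flag frames in which any pair that was bonded in the first frame has stretched well beyond its initial length. The bond cutoff (1.6 Å), the stretch factor (1.5), and the helper `broken_bond_frames` are our assumptions; for the notebook trajectory, the input array could be built as `np.array([a.positions for a in ats_traj])`.

```python
import numpy as np

def broken_bond_frames(traj, bond_cutoff=1.6, stretch_factor=1.5):
    """Indices of frames where a pair bonded in frame 0 exceeds stretch_factor x its initial length.

    traj: array of shape (n_frames, n_atoms, 3), positions in Å.
    """
    diff = traj[:, :, None, :] - traj[:, None, :, :]
    dist = np.linalg.norm(diff, axis=-1)            # (n_frames, n_atoms, n_atoms)
    ref = dist[0]
    bonded = (ref > 0) & (ref < bond_cutoff)        # ref > 0 excludes i == j
    stretched = (dist > stretch_factor * ref) & bonded
    return np.flatnonzero(stretched.any(axis=(1, 2)))

# toy 3-atom trajectory: atoms 0 and 1 start 1.0 Å apart, atom 2 is far away
traj = np.array([
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]],
    [[0.0, 0.0, 0.0], [1.1, 0.0, 0.0], [5.0, 0.0, 0.0]],
    [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [5.0, 0.0, 0.0]],
])
print(broken_bond_frames(traj))  # [2]
```

An empty result means no initially bonded pair stretched beyond the chosen factor in any frame.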