Source code for md.calculators.lj_calculator

from __future__ import annotations
from typing import TYPE_CHECKING, Union, Dict

if TYPE_CHECKING:
    from schnetpack.md.neighborlist_md import NeighborListMD

import torch
import torch.nn as nn

from schnetpack.md.calculators import SchNetPackCalculator

from schnetpack import properties
import schnetpack.nn as snn
from schnetpack.atomistic import Forces, PairwiseDistances, Strain

__all__ = ["LJCalculator", "LJModel"]


[docs]class LJCalculator(SchNetPackCalculator): """ Lennard Jones potential calculator. Primarily used for testing barostats and thermostats. Args: r_equilibrium (float): equilibrium distance in position units well_depth (float): depth of the Lennard-Jones potential in energy units. force_key (str): String indicating the entry corresponding to the molecular forces energy_unit (float, float): Conversion factor converting the energies returned by the used model back to internal MD units. position_unit (float, float): Conversion factor for converting the system positions to the units required by the model. neighbor_list (schnetpack.md.neighbor_list.MDNeighborList): Neighbor list object for determining which interatomic distances should be computed. energy_key (str, optional): If provided, label is used to store the energies returned by the model to the system. stress_key (str, optional): If provided, label is used to store the stress returned by the model to the system (required for constant pressure simulations). property_conversion (dict(float)): Optional dictionary of conversion factors for other properties predicted by the model. Only changes the units used for logging the various outputs. healing_length (float): Healing length used for the cutoff potential. """ def __init__( self, r_equilibrium: float, well_depth: float, force_key: str, energy_unit: Union[str, float], position_unit: Union[str, float], neighbor_list: NeighborListMD, energy_key: str = None, stress_key: str = None, property_conversion: Dict[str, Union[str, float]] = {}, healing_length: float = 3.0, ): model = LJModel( r_equilibrium=r_equilibrium, well_depth=well_depth, cutoff=neighbor_list.cutoff, healing_length=healing_length, calc_forces=True, calc_stress=(stress_key is not None), energy_key=energy_key, force_key=force_key, stress_key=stress_key, ) super(LJCalculator, self).__init__( model, force_key=force_key, energy_unit=energy_unit, position_unit=position_unit, neighbor_list=neighbor_list, energy_key=energy_key, stress_key=stress_key, property_conversion=property_conversion, ) def _prepare_model(self, model_file: LJModel): """ Dummy routine, since no model has to be loaded. Args: model_file (LJModel): Initialized Lennard-Jones model. Returns: LJModel: input model. """ return model_file
class LJModel(nn.Module): """ Lennard Jones potential calculator. Primarily used for testing barostats and thermostats. Args: r_equilibrium (float): equilibrium distance in position units well_depth (float): depth of the Lennard-Jones potential in energy units. cutoff (float): Cutoff radius for computing the neighbor interactions. If this is set to a negative number, the cutoff is determined automatically based on the model (default=-1.0). Units are the distance units used in the model. healing_length (float): Healing length used for the cutoff potential. calc_forces (bool): toggle force computation. calc_stress (bool): toggle stress computation. energy_key (str): Key used for storing energies. force_key (str): Key used for storing forces. stress_key (str): Key used for storing stress. """ def __init__( self, r_equilibrium: float, well_depth: float, cutoff: float, healing_length: float, calc_forces: bool = True, calc_stress: bool = True, energy_key: str = properties.energy, force_key: str = properties.forces, stress_key: str = properties.stress, ): super(LJModel, self).__init__() self.r_equilibrium = r_equilibrium self.well_depth = well_depth self.calc_forces = calc_forces self.calc_stress = calc_stress self.energy_key = energy_key self.force_key = force_key self.stress_key = stress_key self.cutoff_function = CustomCutoff( cutoff_radius=cutoff, healing_length=healing_length ) # Modules for distances, stress tensor and forces self.distances = PairwiseDistances() self.strain = Strain() self.force_layer = Forces( calc_forces=calc_forces, calc_stress=calc_stress, energy_key=energy_key, force_key=force_key, stress_key=stress_key, ) def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Compute the Lennard-Jones energy and forces if requested. Args: inputs (dict(str, torch.Tensor)): Input dictionary. Returns: dict(str, torch.Tensor): Dictionary of model outputs. """ # Activate gradient for force/stress computations if self.calc_forces: inputs[properties.R].requires_grad_() if self.calc_stress: inputs = self.strain(inputs) # Compute interatomic distances inputs = self.distances(inputs) vec_ij = inputs[properties.Rij] positions = inputs[properties.R] idx_i = inputs[properties.idx_i] r_ij = torch.norm(vec_ij, dim=1, keepdim=True) r_cut = self.cutoff_function(r_ij) # Compute lennard jones potential power_6 = torch.pow(self.r_equilibrium / r_ij, 6) power_12 = power_6 * power_6 yij = (power_12 - power_6) * r_cut # aggregate yi = snn.scatter_add(yij, idx_i, dim_size=positions.shape[0]) idx_m = inputs[properties.idx_m] maxm = int(idx_m[-1]) + 1 tmp = torch.zeros((maxm, 1), dtype=yi.dtype, device=yi.device) y = tmp.index_add(0, idx_m, yi) y = torch.squeeze(y, -1) y = self.well_depth * 0.5 * y # collect results inputs[self.energy_key] = y # Compute forces and stress inputs.update(self.force_layer(inputs)) results = {self.energy_key: y} if self.calc_forces: results[self.force_key] = inputs[self.force_key].detach() if self.calc_stress: results[self.stress_key] = inputs[self.stress_key].detach() return results class CustomCutoff(nn.Module): def __init__(self, cutoff_radius: float, healing_length: float): """ Custom cutoff for Lennard-Jones potentials using a healing length. Args: cutoff_radius (float): cutoff radius. healing_length (float): healing length. """ super(CustomCutoff, self).__init__() self.register_buffer("cutoff_radius", torch.Tensor([cutoff_radius])) self.register_buffer("healing_length", torch.Tensor([healing_length])) def forward(self, distances: torch.Tensor) -> torch.Tensor: """ Compute cutoff based on the input distances. Args: distances (torch.tensor): Returns: torch.tensor: cutoff function applied to the distances. """ # Compute basic component r = ( distances - (self.cutoff_radius - self.healing_length) ) / self.healing_length r_function = 1.0 + r**2 * (2.0 * r - 3.0) # Compute second part of cutoff switch = torch.where( distances > self.cutoff_radius - self.healing_length, r_function, torch.ones_like(distances), ) # Compute third component switch = torch.where( distances > self.cutoff_radius, torch.zeros_like(distances), switch ) return switch