Source code for atomistic.nuclear_repulsion

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Callable, Dict, Optional

import schnetpack.properties as properties
import schnetpack.nn as snn
import schnetpack.units as spk_units

__all__ = ["ZBLRepulsionEnergy"]


class ZBLRepulsionEnergy(nn.Module):
    """
    Computes a Ziegler-Biersack-Littmark (ZBL) style screened nuclear repulsion energy.

    For every neighbor pair (i, j), the contribution is

        k_e * Z_i * Z_j / r_ij * sum_k c_k * exp(-e_k * s_ij * r_ij)

    where ``s_ij = (Z_i ** p + Z_j ** p) * a_div`` is an inverse screening length. The
    weights ``c_k`` are normalized to sum to one. The pair terms are optionally multiplied
    by ``cutoff_fn(r_ij)``, summed onto atom ``i``, then summed per molecule and halved.

    All ZBL parameters (``p``, ``a_div``, ``c_k``, ``e_k``) are stored in
    softplus-inverse space. A softplus is applied in ``forward``, which keeps them
    positive while they are being trained.

    Args:
        energy_unit (str/float): Energy unit.
        position_unit (str/float): Unit used for distances.
        output_key (str): Key to which results will be stored
        trainable (bool): If set to true, ZBL parameters will be optimized during training (default=True)
        cutoff_fn (Callable): Apply a cutoff function to the interatomic distances.

    References:
    .. [#Cutoff] Ebert, D. S.; Musgrave, F. K.; Peachey, D.; Perlin, K.; Worley, S.
       Texturing & Modeling: A Procedural Approach; Morgan Kaufmann, 2003
    .. [#ZBL] https://docs.lammps.org/pair_zbl.html
    """

    def __init__(
        self,
        energy_unit: Union[str, float],
        position_unit: Union[str, float],
        output_key: str,
        trainable: bool = True,
        cutoff_fn: Optional[Callable] = None,
    ):
        super().__init__()

        # Conversion factors from atomic units (Hartree, Bohr) to the requested units.
        ha_to_energy = spk_units.convert_units("Ha", energy_unit)
        bohr_to_length = spk_units.convert_units("Bohr", position_unit)

        # The Coulomb prefactor is 1 Ha*Bohr in atomic units. Express it as
        # energy_unit * position_unit so that Z_i * Z_j / r comes out in energy_unit.
        self.register_buffer("ke", torch.tensor(ha_to_energy * bohr_to_length))
        self.cutoff_fn = cutoff_fn
        self.output_key = output_key

        # Reference ZBL constants (atomic units), mapped through the inverse softplus.
        # Applying a softplus in forward() then recovers the positive values.
        inv_softplus = snn.softplus_inverse
        # 1 / (0.8854 * a0), with a0 expressed in position_unit, so that raw distances
        # can be multiplied in directly without rescaling.
        raw_a_div = inv_softplus(torch.tensor([1.0 / (bohr_to_length * 0.8854)]))
        raw_a_pow = inv_softplus(torch.tensor([0.23]))
        raw_exponents = inv_softplus(torch.tensor([3.19980, 0.94229, 0.40290, 0.20162]))
        raw_coefficients = inv_softplus(
            torch.tensor([0.18175, 0.50986, 0.28022, 0.02817])
        )

        # Registration order kept as a_pow, a_div, coefficients, exponents so that
        # parameters() ordering (and any saved optimizer state) stays the same.
        self.a_pow = nn.Parameter(raw_a_pow, requires_grad=trainable)
        self.a_div = nn.Parameter(raw_a_div, requires_grad=trainable)
        self.coefficients = nn.Parameter(raw_coefficients, requires_grad=trainable)
        self.exponents = nn.Parameter(raw_exponents, requires_grad=trainable)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Compute the ZBL repulsion energy per molecule and store it in ``inputs``.

        Args:
            inputs: Batch dictionary. It must provide atomic numbers, pair distance
                vectors, pair indices (idx_i, idx_j) and the molecule index of each atom
                (idx_m), under the keys defined in ``schnetpack.properties``.

        Returns:
            The same ``inputs`` dict, with ``inputs[self.output_key]`` set to a tensor
            holding one energy per molecule.
        """
        atomic_numbers = inputs[properties.Z]
        # NOTE(review): assumes Rij is (n_pairs, 3); the norm is taken over dim 1.
        distances = torch.norm(inputs[properties.Rij], dim=1)
        idx_i = inputs[properties.idx_i]
        idx_j = inputs[properties.idx_j]
        idx_m = inputs[properties.idx_m]

        n_atoms = atomic_numbers.shape[0]
        # Assumes idx_m is sorted, so its last entry is the highest molecule index.
        n_molecules = int(idx_m[-1]) + 1

        # Inverse screening length per pair: (Z_i^p + Z_j^p) * a_div.
        z_powered = atomic_numbers ** F.softplus(self.a_pow)
        inv_screening = (z_powered[idx_i] + z_powered[idx_j]) * F.softplus(self.a_div)

        # Per-pair decay rates and L1-normalized mixing weights of the exponentials.
        decay = inv_screening[..., None] * F.softplus(self.exponents)[None, ...]
        weights = F.normalize(F.softplus(self.coefficients)[None, ...], p=1.0, dim=1)

        screening = (weights * torch.exp(-decay * distances[:, None])).sum(dim=1)

        # Unscreened Coulomb term Z_i * Z_j / r_ij, optionally damped by the cutoff.
        coulomb = (atomic_numbers[idx_i] * atomic_numbers[idx_j]) / distances
        if self.cutoff_fn is not None:
            coulomb = coulomb * self.cutoff_fn(distances)

        # Pair energies -> atoms (via idx_i) -> molecules (via idx_m).
        per_atom = snn.scatter_add(coulomb * screening, idx_i, dim_size=n_atoms)
        per_molecule = snn.scatter_add(per_atom, idx_m, dim_size=n_molecules)

        # The 0.5 presumably compensates for each pair appearing as both (i, j) and
        # (j, i) in the neighbor list. Confirm against the neighbor-list provider.
        inputs[self.output_key] = 0.5 * self.ke * per_molecule

        return inputs