import torch
import torch.nn as nn
from typing import Union, Dict, Optional
from schnetpack import units as spk_units
import schnetpack.nn as snn
from schnetpack import properties
import numpy as np
# Public API of this module.
__all__ = [
    "CoulombPotential",
    "DampedCoulombPotential",
    "EnergyCoulomb",
    "EnergyEwald",
]
class CoulombPotential(nn.Module):
    """
    Plain, undamped 1/r Coulomb kernel. For use in `schnetpack.atomistic.EnergyCoulomb`.
    """

    def forward(self, d_ij: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the bare Coulomb kernel.

        Args:
            d_ij (torch.Tensor): Interatomic distances.

        Returns:
            torch.Tensor: Elementwise inverse distances ``1 / d_ij``.
        """
        inverse_distance = 1.0 / d_ij
        return inverse_distance
class DampedCoulombPotential(nn.Module):
    """
    Damped Coulomb kernel following [#physnet1]_. For use in `schnetpack.atomistic.EnergyCoulomb`.

    The switch function blends between a softened kernel ``1 / sqrt(d^2 + 1)`` (weight ``f``)
    and the bare kernel ``1 / d`` (weight ``1 - f``).

    Args:
        switch_fn (torch.nn.Module): Switch function.

    References:
    .. [#physnet1] O.Unke, M.Meuwly
       PhysNet: A Neural Network for Predicting Energies, Forces, Dipole Moments and Partial Charges
       https://arxiv.org/abs/1902.08408
    """

    def __init__(self, switch_fn: nn.Module):
        super().__init__()
        self.switch_fn = switch_fn

    def forward(self, d_ij: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the damped Coulomb kernel.

        Args:
            d_ij (torch.Tensor): Interatomic distances.

        Returns:
            torch.Tensor: Blend of softened and bare inverse distances.
        """
        weight = self.switch_fn(d_ij)
        bare = 1.0 / d_ij
        softened = 1.0 / torch.sqrt(d_ij**2 + 1)
        # weight -> 1 selects the softened kernel, weight -> 0 the plain 1/r kernel
        return (1 - weight) * bare + weight * softened
class EnergyCoulomb(nn.Module):
    """
    Compute Coulomb energy from a set of point charges via direct summation. Depending on the form of the
    potential function, the interaction can be damped for short distances. If a cutoff is requested, the full
    potential is shifted, so that it and its first derivative are zero starting from the cutoff.

    Args:
        energy_unit (str/float): Units used for the energy.
        position_unit (str/float): Units used for lengths and positions.
        coulomb_potential (torch.nn.Module): Distance part of the potential.
        output_key (str): Name of the energy property in the output.
        charges_key (str): Key of partial charges in the input batch.
        use_neighbors_lr (bool): Whether to use standard or long range neighbor list elements (default = True).
        cutoff (optional, float): Apply a long range cutoff (potential is shifted to 0, default=None).
    """

    def __init__(
        self,
        energy_unit: Union[str, float],
        position_unit: Union[str, float],
        coulomb_potential: nn.Module,
        output_key: str,
        charges_key: str = properties.partial_charges,
        use_neighbors_lr: bool = True,
        cutoff: Optional[float] = None,
    ):
        super(EnergyCoulomb, self).__init__()

        # Coulomb constant (1 Ha * Bohr in atomic units) converted to the requested energy/length units
        ke = spk_units.convert_units("Ha", energy_unit) * spk_units.convert_units(
            "Bohr", position_unit
        )
        self.register_buffer("ke", torch.Tensor([ke]))

        self.coulomb_potential = coulomb_potential
        self.charges_key = charges_key
        self.output_key = output_key
        self.model_outputs = [output_key]
        self.use_neighbors_lr = use_neighbors_lr

        if cutoff is not None:
            # Always store a floating point cutoff, so an integer argument does not yield an
            # integer buffer (and an integer tensor passed to the potential function).
            cutoff = torch.tensor(float(cutoff))
            # Potential value at the cutoff, used to shift the potential in `forward`
            shift = self.coulomb_potential(cutoff).detach()
            self.register_buffer("cutoff", cutoff)
            self.register_buffer("shift", shift)
        else:
            self.cutoff = None
            self.shift = None

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Compute the Coulomb energy.

        Args:
            inputs (dict(str,torch.Tensor)): Input batch.

        Returns:
            dict(str, torch.Tensor): results with Coulomb energy.
        """
        # One charge per atom. reshape(-1) instead of squeeze(-1): squeeze would turn a
        # single-atom charge tensor of shape (1,) into a 0-d tensor and break the indexing below.
        q = inputs[self.charges_key].reshape(-1)
        idx_m = inputs[properties.idx_m]

        # Use long range neighbor list if requested
        if self.use_neighbors_lr:
            r_ij = inputs[properties.Rij_lr]
            idx_i = inputs[properties.idx_i_lr]
            idx_j = inputs[properties.idx_j_lr]
        else:
            r_ij = inputs[properties.Rij]
            idx_i = inputs[properties.idx_i]
            idx_j = inputs[properties.idx_j]

        d_ij = torch.norm(r_ij, dim=1)

        n_atoms = q.shape[0]
        # NOTE(review): assumes idx_m is sorted, so its last entry is the largest molecule index
        n_molecules = int(idx_m[-1]) + 1

        q_ij = q[idx_i] * q[idx_j]
        potential = self.coulomb_potential(d_ij)

        # Apply cutoff if requested (shifting to zero). With V_c = V(cutoff), the shifted form
        # V + V_c^2 / V - 2 V_c equals 1/r + r/r_c^2 - 2/r_c for V = 1/r, which vanishes together
        # with its first derivative at r = r_c. Pairs beyond the cutoff contribute nothing.
        if self.cutoff is not None:
            potential = potential + self.shift**2 / potential - 2.0 * self.shift
            potential = torch.where(
                d_ij <= self.cutoff, potential, torch.zeros_like(potential)
            )

        # Accumulate pair terms on atom i, then atoms on their molecules
        y = snn.scatter_add((q_ij * potential), idx_i, dim_size=n_atoms)
        y = snn.scatter_add(y, idx_m, dim_size=n_molecules)
        # Factor 0.5 presumably compensates for every pair being listed as both (i, j) and (j, i)
        y = 0.5 * self.ke * torch.squeeze(y, -1)

        inputs[self.output_key] = y
        return inputs
class EnergyEwaldError(Exception):
    """Error raised by `EnergyEwald`, e.g. when a simulation cell has (numerically) zero volume."""
class EnergyEwald(torch.nn.Module):
    """
    Compute the Coulomb energy of a set of point charges inside a periodic box.
    Only works for periodic boundary conditions in all three spatial directions and orthorhombic boxes.

    Args:
        alpha (float): Ewald splitting parameter. It enters the real space term as ``erfc(sqrt(alpha) * r)``,
            the reciprocal space term as ``exp(-k^2 / (4 * alpha))`` and the self interaction as
            ``sqrt(alpha / pi)``.
        k_max (int): Cutoff for the reciprocal lattice vectors. All non-zero integer triples ``n`` with
            components in ``[-k_max, k_max]`` and ``|n|^2 <= k_max^2 + 2`` are used.
        energy_unit (str/float): Units used for the energy.
        position_unit (str/float): Units used for lengths and positions.
        output_key (str): Name of the energy property in the output.
        charges_key (str): Key of partial charges in the input batch.
        use_neighbors_lr (bool): Whether to use standard or long range neighbor list elements (default = True).
        screening_fn (optional, torch.nn.Module): Screening function for the real space interaction; the
            real space kernel is multiplied by ``1 - screening_fn(d_ij)`` (default = None).
    """

    def __init__(
        self,
        alpha: float,
        k_max: int,
        energy_unit: Union[str, float],
        position_unit: Union[str, float],
        output_key: str,
        charges_key: str = properties.partial_charges,
        use_neighbors_lr: bool = True,
        screening_fn: Optional[nn.Module] = None,
    ):
        super(EnergyEwald, self).__init__()

        # Coulomb constant (1 Ha * Bohr in atomic units) converted to the requested energy/length units
        ke = spk_units.convert_units("Ha", energy_unit) * spk_units.convert_units(
            "Bohr", position_unit
        )
        self.register_buffer("ke", torch.Tensor([ke]))

        self.charges_key = charges_key
        self.output_key = output_key
        self.model_outputs = [output_key]
        self.use_neighbors_lr = use_neighbors_lr
        self.screening_fn = screening_fn

        # TODO: automatic computation of alpha
        # Registered before the k-vectors are generated, since they take over its dtype
        self.register_buffer("alpha", torch.Tensor([alpha]))

        # Set up integer lattice vectors (converted to reciprocal space in `_reciprocal_space`)
        self.k_max = k_max
        kvecs = self._generate_kvecs()
        self.register_buffer("kvecs", kvecs)

    def _generate_kvecs(self) -> torch.Tensor:
        """
        Auxiliary routine for setting up the k-vectors.

        Returns:
            torch.Tensor: Integer-valued k-vectors (K x 3) in units of the reciprocal lattice vectors.
        """
        krange = torch.arange(0, self.k_max + 1, dtype=self.alpha.dtype)
        # 0, 1, ..., k_max, -1, ..., -k_max
        krange = torch.cat([krange, -krange[1:]])
        kvecs = torch.cartesian_prod(krange, krange, krange)
        norm = torch.sum(kvecs**2, dim=1)
        # Keep vectors inside the (slightly enlarged) sphere and drop the divergent k = 0 term
        mask = (norm <= self.k_max**2 + 2) & (norm != 0)
        return kvecs[mask, :]

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Compute the Coulomb energy of the periodic system.

        Args:
            inputs (dict(str,torch.Tensor)): Input batch.

        Returns:
            dict(str, torch.Tensor): results with Coulomb energy.
        """
        # One charge per atom. reshape(-1) instead of squeeze(-1): squeeze would turn a
        # single-atom charge tensor of shape (1,) into a 0-d tensor and break the indexing below.
        q = inputs[self.charges_key].reshape(-1)
        idx_m = inputs[properties.idx_m]

        # Use long range neighbor list if requested
        if self.use_neighbors_lr:
            r_ij = inputs[properties.Rij_lr]
            idx_i = inputs[properties.idx_i_lr]
            idx_j = inputs[properties.idx_j_lr]
        else:
            r_ij = inputs[properties.Rij]
            idx_i = inputs[properties.idx_i]
            idx_j = inputs[properties.idx_j]

        d_ij = torch.norm(r_ij, dim=1)

        positions = inputs[properties.R]
        cell = inputs[properties.cell]

        n_atoms = q.shape[0]
        # NOTE(review): assumes idx_m is sorted, so its last entry is the largest molecule index
        n_molecules = int(idx_m[-1]) + 1

        # Get real space and reciprocal space contributions
        y_real = self._real_space(q, d_ij, idx_i, idx_j, idx_m, n_atoms, n_molecules)
        y_reciprocal = self._reciprocal_space(q, positions, cell, idx_m, n_molecules)

        y = y_real + y_reciprocal

        inputs[self.output_key] = y
        return inputs

    def _real_space(
        self,
        q: torch.Tensor,
        d_ij: torch.Tensor,
        idx_i: torch.Tensor,
        idx_j: torch.Tensor,
        idx_m: torch.Tensor,
        n_atoms: int,
        n_molecules: int,
    ) -> torch.Tensor:
        """
        Compute the real space contribution of the screened charges.

        Args:
            q (torch.Tensor): Partial charges.
            d_ij (torch.Tensor): Interatomic distances.
            idx_i (torch.Tensor): Indices of atoms i in the distance pairs.
            idx_j (torch.Tensor): Indices of atoms j in the distance pairs.
            idx_m (torch.Tensor): Molecular indices of each atom.
            n_atoms (int): Total number of atoms.
            n_molecules (int): Number of molecules.

        Returns:
            torch.Tensor: Real space Coulomb energy.
        """
        # Apply erfc for Ewald summation
        f_erfc = torch.erfc(torch.sqrt(self.alpha) * d_ij)
        # Combine functions and multiply with inverse distance
        f_r = f_erfc / d_ij

        # Apply screening function
        if self.screening_fn is not None:
            screen = self.screening_fn(d_ij)
            f_r = f_r * (1.0 - screen)

        potential_ij = q[idx_i] * q[idx_j] * f_r

        # Accumulate pair terms on atom i, then atoms on their molecules
        y = snn.scatter_add(potential_ij, idx_i, dim_size=n_atoms)
        y = snn.scatter_add(y, idx_m, dim_size=n_molecules)
        # Factor 0.5 presumably compensates for every pair being listed as both (i, j) and (j, i)
        y = 0.5 * self.ke * y.squeeze(-1)

        return y

    def _reciprocal_space(
        self,
        q: torch.Tensor,
        positions: torch.Tensor,
        cell: torch.Tensor,
        idx_m: torch.Tensor,
        n_molecules: int,
    ) -> torch.Tensor:
        """
        Compute the reciprocal space contribution, including the self interaction correction.

        Args:
            q (torch.Tensor): Partial charges.
            positions (torch.Tensor): Atom positions (N x 3).
            cell (torch.Tensor): Molecular cells (M x 3 x 3).
            idx_m (torch.Tensor): Molecular indices of each atom.
            n_molecules (int): Number of molecules.

        Returns:
            torch.Tensor: Reciprocal space Coulomb energy per molecule.

        Raises:
            EnergyEwaldError: If any cell has (numerically) zero volume.
        """
        # Check the box volumes BEFORE inverting the cells: a degenerate (e.g. all-zero) cell is
        # singular, and torch.linalg.inv would fail first, making the explicit error unreachable.
        v_box = torch.abs(torch.linalg.det(cell))
        if torch.any(torch.isclose(v_box, torch.zeros_like(v_box))):
            raise EnergyEwaldError("Simulation box has no volume.")

        # Reciprocal lattice vectors as rows: 2 pi (cell^-1)^T
        # (assuming the lattice vectors are the rows of `cell`)
        recip_box = 2.0 * np.pi * torch.linalg.inv(cell).transpose(1, 2)

        # 1) compute the prefactor
        prefactor = 2.0 * np.pi / v_box

        # setup kvecs M x K x 3
        kvecs = torch.matmul(self.kvecs[None, :, :], recip_box)

        # Squared length of vectors M x K
        k_squared = torch.sum(kvecs**2, dim=2)

        # 2) Gaussian part of ewald sum
        q_gauss = torch.exp(-0.25 * k_squared / self.alpha)  # M x K

        # 3) Compute charge density fourier terms
        # Dot product in exponent -> MN x K, expand kvecs in MN batch structure
        kvec_dot_pos = torch.sum(kvecs[idx_m] * positions[:, None, :], dim=2)

        # charge densities MN x K -> M x K
        q_real = snn.scatter_add(
            (q[:, None] * torch.cos(kvec_dot_pos)), idx_m, dim_size=n_molecules
        )
        q_imag = snn.scatter_add(
            (q[:, None] * torch.sin(kvec_dot_pos)), idx_m, dim_size=n_molecules
        )

        # Compute square of density
        q_dens = q_real**2 + q_imag**2

        # Sum over k vectors -> M x K -> M
        y_ewald = prefactor * torch.sum(q_dens * q_gauss / k_squared, dim=1)

        # 4) self interaction correction, summed per molecule -> M
        self_interaction = torch.sqrt(self.alpha / np.pi) * snn.scatter_add(
            q**2, idx_m, dim_size=n_molecules
        )

        # Bring everything together
        y_ewald = self.ke * (y_ewald - self_interaction)

        return y_ewald