"""
This module contains barostats for controlling the pressure of the system during
ring polymer molecular dynamics simulations.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from schnetpack.md import Simulator, System
import torch
from schnetpack.md.simulation_hooks import BarostatHook
from schnetpack import units as spk_units
from schnetpack.md.utils import StableSinhDiv
__all__ = ["PILEBarostat"]
class PILEBarostat(BarostatHook):
    """
    Barostat for ring polymer molecular dynamics simulations. This barostat is based on the algorithm introduced in
    [#rpmd_barostat1]_ and uses a PILE thermostat for the cell. It can be combined with any RPMD thermostat for
    temperature control. The barostat only acts on the centroid mode of the ring polymer and is designed for isotropic
    cell fluctuations.

    Args:
        target_pressure (float): Target pressure of the system (in bar).
        temperature_bath (float): Target temperature applied to the cell fluctuations.
        time_constant (float): Time constant of the cell thermostat. Its inverse is used as the thermostat
                               frequency and also determines the barostat mass.

    References
    ----------
    .. [#rpmd_barostat1] Kapil, et al.
       i-PI 2.0: A universal force engine for advanced molecular simulations.
       Computer Physics Communications, 236, 214-223. 2019.
    """

    # Class-level flags; presumably read by the simulator/integrator to check compatibility
    # (their exact use is defined outside this class -- confirm against BarostatHook).
    temperature_control = False
    ring_polymer = True

    def __init__(
        self, target_pressure: float, temperature_bath: float, time_constant: float
    ):
        super().__init__(
            target_pressure=target_pressure,
            temperature_bath=temperature_bath,
            time_constant=time_constant,
        )
        # Thermostat frequency is the inverse of the time constant stored by the base class
        self.register_buffer("frequency", 1.0 / self.time_constant)

        # Compute kBT, since it will be used a lot
        self.register_buffer("kb_temperature", self.temperature_bath * spk_units.kB)

        # Quantities that can only be set once the simulator is known (see _init_barostat)
        self.register_uninitialized_buffer("propagator")
        self.register_uninitialized_buffer("cell_momenta")
        self.register_uninitialized_buffer("c1")
        self.register_uninitialized_buffer("c2")

        # Helper used to evaluate the sinh-based coefficient in propagate_main_step
        self.sinhdx = StableSinhDiv()

    def _init_barostat(self, simulator: Simulator):
        """
        Initialize the thermostat coefficients and barostat quantities.

        Args:
            simulator (schnetpack.md.Simulator): Main simulator class containing information on
                                                 the time step, system, etc.
        """
        # Reuse the integrator's propagator for the position and momenta update of the non-centroid modes
        self.propagator = simulator.integrator.propagator

        # Set up centroid momenta of cell (one for every molecule)
        self.cell_momenta = torch.zeros(
            simulator.system.n_molecules, device=simulator.device, dtype=simulator.dtype
        )

        # Barostat mass: 3 * n_atoms * kBT / frequency^2.
        # NOTE(review): unlike the other quantities set here, `mass` is a plain attribute that was not
        # registered in __init__ -- confirm whether it should be part of the module state.
        self.mass = (
            3.0 * simulator.system.n_atoms / self.frequency**2 * self.kb_temperature
        ).to(simulator.dtype)

        # Set up cell thermostat coefficients: c1 is the friction decay over half a time step ...
        self.c1 = torch.exp(
            -0.5
            * torch.ones(1, device=simulator.device, dtype=simulator.dtype)
            * self.frequency
            * self.time_step
        )
        # ... and c2 is the matching amplitude of the random force
        self.c2 = torch.sqrt(
            simulator.system.n_replicas
            * self.mass
            * self.kb_temperature
            * (1.0 - self.c1**2)
        )

    def on_step_begin(self, simulator: Simulator):
        """
        Apply the cell thermostat at the beginning of a simulation step.

        Args:
            simulator (schnetpack.md.Simulator): Main simulator class (not used here).
        """
        self._update_barostat()

    def on_step_end(self, simulator: Simulator):
        """
        Apply the cell thermostat at the end of a simulation step.

        Args:
            simulator (schnetpack.md.Simulator): Main simulator class (not used here).
        """
        self._update_barostat()

    def _update_barostat(self):
        """
        Apply the thermostat. This simply propagates the cell momenta under the influence of a PILE thermostat.
        """
        # Propagate cell momenta during half-step: damp with c1, add Gaussian noise scaled by c2
        self.cell_momenta = self.c1 * self.cell_momenta + self.c2 * torch.randn_like(
            self.cell_momenta
        )

    def propagate_main_step(self, system: System):
        """
        Main routine for propagating the ring polymer and the cells. The barostat acts only on the centroid, while the
        remaining replicas are propagated in the conventional way.

        Args:
            system (schnetpack.md.System): System class containing all molecules and their
                                           replicas.
        """
        # Get coefficients (cell momenta divided by the barostat mass, with added replica and trailing dims)
        reduced_momenta = (self.cell_momenta / self.mass)[None, :, None]
        coeff_a = torch.exp(-self.time_step * reduced_momenta)
        coeff_b = self.sinhdx.f(self.time_step * reduced_momenta)

        # Expand the per-molecule coefficients to per-atom quantities
        coeff_a_atomic = system.expand_atoms(coeff_a)
        coeff_b_atomic = system.expand_atoms(coeff_b)

        # Get positions and momenta in normal mode representation
        positions_normal = system.positions_normal
        momenta_normal = system.momenta_normal

        # Propagate centroid mode of the ring polymer. The momenta are scaled first, so the position
        # update below already uses the scaled centroid momenta.
        momenta_normal[0:1] = momenta_normal[0:1] * coeff_a_atomic
        positions_normal[0:1] = (
            positions_normal[0:1] / coeff_a_atomic
            + coeff_b_atomic
            * (momenta_normal[0:1] / system.masses[0:1])
            * self.time_step
        )

        # Update cells
        system.cells = system.cells / coeff_a[..., None]

        # Create copies to avoid overwriting: the position update must see the momenta from before
        # the in-place momenta update of the remaining modes
        positions_normal_tmp = positions_normal.clone()
        momenta_normal_tmp = momenta_normal.clone()

        # Propagate the remaining modes of the ring polymer
        momenta_normal[1:] = (
            self.propagator[1:, 0, 0] * momenta_normal_tmp[1:]
            + self.propagator[1:, 0, 1] * positions_normal_tmp[1:] * system.masses
        )
        positions_normal[1:] = (
            self.propagator[1:, 1, 0] * momenta_normal_tmp[1:] / system.masses
            + self.propagator[1:, 1, 1] * positions_normal_tmp[1:]
        )

        # Write the normal mode quantities back to the system
        system.positions_normal = positions_normal
        system.momenta_normal = momenta_normal

    def propagate_half_step(self, system: System):
        """
        Advance the cell momenta of the barostat and the particle momenta by half a time step.

        Args:
            system (schnetpack.md.System): System class containing all molecules and their
                                           replicas.
        """
        half_step = 0.5 * self.time_step

        centroid_momenta = system.momenta_normal[0]
        centroid_forces = system.nm_transform.beads2normal(system.forces)[0]

        # Compute pressure component: 3 * n_replicas * (V * (P_centroid - P_target) + kBT),
        # where V is the volume averaged over the replica dimension
        component_1 = (
            3.0
            * system.n_replicas
            * (
                torch.mean(system.volume, dim=0)
                * (
                    system.compute_centroid_pressure(kinetic_component=True)[0]
                    - self.target_pressure
                )
                + self.kb_temperature
            )
        )

        # Compute components based on forces and momenta (summed over Cartesian components and atoms)
        force_by_mass = centroid_forces / system.masses[0]
        component_2 = system.sum_atoms(
            torch.sum(force_by_mass * centroid_momenta, dim=1)[None, ...]
        )[0]
        component_3 = system.sum_atoms(
            torch.sum(force_by_mass * centroid_forces / 3.0, dim=1)[None, ...]
        )[0]

        # Update cell momenta (terms of first, second and third order in the half time step)
        self.cell_momenta += (
            half_step * component_1[:, 0]
            + half_step**2 * component_2
            + half_step**3 * component_3
        )

        # Update the system momenta
        system.momenta = system.momenta + 0.5 * system.forces * self.time_step