Source code for md.simulation_hooks.barostats_rpmd

"""
This module contains barostats for controlling the pressure of the system during
ring polymer molecular dynamics simulations.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from schnetpack.md import Simulator, System

import torch
from schnetpack.md.simulation_hooks import BarostatHook
from schnetpack import units as spk_units
from schnetpack.md.utils import StableSinhDiv

__all__ = ["PILEBarostat"]


class PILEBarostat(BarostatHook):
    """
    Barostat for ring polymer molecular dynamics simulations.

    This barostat is based on the algorithm introduced in [#rpmd_barostat1]_ and
    uses a PILE thermostat for the cell. It can be combined with any RPMD
    thermostat for temperature control. The barostat only acts on the centroid
    mode of the ring polymer and is designed for isotropic cell fluctuations.

    Args:
        target_pressure (float): Target pressure of the system (in bar).
        temperature_bath (float): Target temperature applied to the cell
            fluctuations.
        time_constant (float): Time constant of the barostat. Its inverse is
            used as the frequency of the cell thermostat and enters the
            fictitious cell mass. Unit handling is left to ``BarostatHook``.

    References
    ----------
    .. [#rpmd_barostat1] Kapil, et al.
       i-PI 2.0: A universal force engine for advanced molecular simulations.
       Computer Physics Communications, 236, 214-223. 2019.
    """

    # Class-level flags describing this hook: it does not thermostat the
    # particles itself and it operates on ring polymer systems.
    temperature_control = False
    ring_polymer = True

    def __init__(
        self, target_pressure: float, temperature_bath: float, time_constant: float
    ):
        super().__init__(
            target_pressure=target_pressure,
            temperature_bath=temperature_bath,
            time_constant=time_constant,
        )
        # Frequency of the cell thermostat (inverse of the time constant)
        self.register_buffer("frequency", 1.0 / self.time_constant)
        # Compute kBT, since it will be used a lot
        self.register_buffer("kb_temperature", self.temperature_bath * spk_units.kB)

        # Quantities that can only be set up once the simulator is known
        # (see _init_barostat)
        self.register_uninitialized_buffer("propagator")
        self.register_uninitialized_buffer("cell_momenta")
        self.register_uninitialized_buffer("c1")
        self.register_uninitialized_buffer("c2")

        # Helper evaluated via ``.f(x)`` in propagate_main_step
        # (presumably a numerically stable sinh(x)/x -- confirm in StableSinhDiv)
        self.sinhdx = StableSinhDiv()

    def _init_barostat(self, simulator: Simulator) -> None:
        """
        Initialize the thermostat coefficients and barostat quantities.

        Sets the ring polymer propagator, the cell momenta (initialized to
        zero), the fictitious cell mass and the coefficients ``c1`` and ``c2``
        of the cell thermostat.

        Args:
            simulator (schnetpack.md.Simulator): Main simulator class containing
                information on the time step, system, etc.
        """
        # Get normal mode transformer and propagator for position and cell update
        self.propagator = simulator.integrator.propagator

        # Set up centroid momenta of cell (one for every molecule)
        self.cell_momenta = torch.zeros(
            simulator.system.n_molecules,
            device=simulator.device,
            dtype=simulator.dtype,
        )

        # Fictitious mass associated with the cell degree of freedom.
        # NOTE(review): unlike c1/c2/cell_momenta, ``mass`` is assigned as a
        # plain attribute and was never registered as a buffer in __init__ --
        # confirm this is intended (e.g. for restarts from a saved state).
        self.mass = (
            3.0 * simulator.system.n_atoms / self.frequency**2 * self.kb_temperature
        ).to(simulator.dtype)

        # Set up cell thermostat coefficients. ``self.time_step`` is not set in
        # this class; it is expected to be provided before this method runs
        # (presumably by BarostatHook -- confirm).
        self.c1 = torch.exp(
            -0.5
            * torch.ones(1, device=simulator.device, dtype=simulator.dtype)
            * self.frequency
            * self.time_step
        )
        self.c2 = torch.sqrt(
            simulator.system.n_replicas
            * self.mass
            * self.kb_temperature
            * (1.0 - self.c1**2)
        )

    def on_step_begin(self, simulator: Simulator) -> None:
        """
        Apply the cell thermostat at the beginning of a simulation step.

        Args:
            simulator (schnetpack.md.Simulator): Main simulator class (unused
                here, required by the hook interface).
        """
        self._update_barostat()

    def on_step_end(self, simulator: Simulator) -> None:
        """
        Apply the cell thermostat at the end of a simulation step.

        Args:
            simulator (schnetpack.md.Simulator): Main simulator class (unused
                here, required by the hook interface).
        """
        self._update_barostat()

    def _update_barostat(self) -> None:
        """
        Apply the thermostat. This simply propagates the cell momenta under the
        influence of a PILE thermostat: the momenta are damped by ``c1`` and
        receive Gaussian random noise scaled by ``c2``.
        """
        # Propagate cell momenta during half-step
        self.cell_momenta = self.c1 * self.cell_momenta + self.c2 * torch.randn_like(
            self.cell_momenta
        )

    def propagate_main_step(self, system: System) -> None:
        """
        Main routine for propagating the ring polymer and the cells. The barostat
        acts only on the centroid, while the remaining replicas are propagated in
        the conventional way.

        The updated normal mode positions and momenta are written back to
        ``system.positions_normal`` / ``system.momenta_normal`` and the cells in
        ``system.cells`` are rescaled.

        Args:
            system (schnetpack.md.System): System class containing all molecules
                and their replicas.
        """
        # Get coefficients (broadcast over the leading and the trailing dimension)
        reduced_momenta = (self.cell_momenta / self.mass)[None, :, None]
        coeff_a = torch.exp(-self.time_step * reduced_momenta)
        coeff_b = self.sinhdx.f(self.time_step * reduced_momenta)

        # Expand the per-molecule coefficients to per-atom coefficients
        coeff_a_atomic = system.expand_atoms(coeff_a)
        coeff_b_atomic = system.expand_atoms(coeff_b)

        # Transform to normal mode representation
        positions_normal = system.positions_normal
        momenta_normal = system.momenta_normal

        # Propagate centroid mode of the ring polymer. The momenta are updated
        # first, so the position update below uses the already scaled momenta.
        momenta_normal[0:1] = momenta_normal[0:1] * coeff_a_atomic
        positions_normal[0:1] = (
            positions_normal[0:1] / coeff_a_atomic
            + coeff_b_atomic
            * (momenta_normal[0:1] / system.masses[0:1])
            * self.time_step
        )

        # Update cells
        system.cells = system.cells / coeff_a[..., None]

        # Create copies to avoid overwriting
        positions_normal_tmp = positions_normal.clone()
        momenta_normal_tmp = momenta_normal.clone()

        # Propagate the remaining modes of the ring polymer
        momenta_normal[1:] = (
            self.propagator[1:, 0, 0] * momenta_normal_tmp[1:]
            + self.propagator[1:, 0, 1] * positions_normal_tmp[1:] * system.masses
        )
        positions_normal[1:] = (
            self.propagator[1:, 1, 0] * momenta_normal_tmp[1:] / system.masses
            + self.propagator[1:, 1, 1] * positions_normal_tmp[1:]
        )

        # Transform back to bead representation
        system.positions_normal = positions_normal
        system.momenta_normal = momenta_normal

    def propagate_half_step(self, system: System) -> None:
        """
        Propagate the momenta of the thermostat attached to the barostat during
        each half-step, as well as the particle momenta.

        Args:
            system (schnetpack.md.System): System class containing all molecules
                and their replicas.
        """
        centroid_momenta = system.momenta_normal[0]
        centroid_forces = system.nm_transform.beads2normal(system.forces)[0]

        # Compute pressure component. The volume is averaged over the first
        # dimension; the original note states that the volume is scaled equally
        # for all replicas, so this should coincide with volume[0].
        component_1 = (
            3.0
            * system.n_replicas
            * (
                torch.mean(system.volume, dim=0)
                * (
                    system.compute_centroid_pressure(kinetic_component=True)[0]
                    - self.target_pressure
                )
                + self.kb_temperature
            )
        )

        # Compute components based on forces and momenta
        force_by_mass = centroid_forces / system.masses[0]

        component_2 = system.sum_atoms(
            torch.sum(force_by_mass * centroid_momenta, dim=1)[None, ...]
        )[0]
        component_3 = system.sum_atoms(
            torch.sum(force_by_mass * centroid_forces / 3.0, dim=1)[None, ...]
        )[0]

        # Update cell momenta (in place) with increasing powers of the half time step
        self.cell_momenta += (
            +(0.5 * self.time_step) * component_1[:, 0]
            + (0.5 * self.time_step) ** 2 * component_2
            + (0.5 * self.time_step) ** 3 * component_3
        )

        # Update the system momenta
        system.momenta = system.momenta + 0.5 * system.forces * self.time_step