"""
This module contains various barostats for controlling the pressure of the system during
molecular dynamics simulations.
"""
from __future__ import annotations
from typing import Optional, Tuple, TYPE_CHECKING
if TYPE_CHECKING:
from schnetpack.md import Simulator, System
import torch
from schnetpack import units as spk_units
# from schnetpack.md import System, Simulator
from schnetpack.md.simulation_hooks import SimulationHook
from schnetpack.md.utils import StableSinhDiv, YSWeights
__all__ = ["BarostatHook", "NHCBarostatIsotropic", "NHCBarostatAnisotropic"]
class BarostatError(Exception):
"""
Exception for barostat hooks.
"""
pass
[docs]class BarostatHook(SimulationHook):
"""
Basic barostat hook for simulator class. This class is initialized based on the simulator and system
specifications during the first MD step. Barostats are applied before and after each MD step. In addition,
they modify the update of positions and cells, which is why they have to be used with modified integrators.
Args:
target_pressure (float): Target pressure of the system (in bar).
temperature_bath (float): Target temperature applied to the cell fluctuations (in K).
time_constant (float): Time constant used for thermostatting if available (in fs).
"""
ring_polymer = False
temperature_control = False
def __init__(
self, target_pressure: float, temperature_bath: float, time_constant: float
):
super(BarostatHook, self).__init__()
# Convert pressure from bar to internal units
self.register_buffer(
"target_pressure", torch.tensor(target_pressure * spk_units.bar)
)
self.register_buffer("temperature_bath", torch.tensor(temperature_bath))
self.register_buffer(
"time_constant", torch.tensor(time_constant * spk_units.fs)
)
self.register_buffer("_initialized", torch.tensor(False))
# This should be determined automatically and does not need to be stored in buffer
self.time_step = None
@property
def initialized(self):
"""
Auxiliary property for easy access to initialized flag used for restarts
"""
return self._initialized.item()
@initialized.setter
def initialized(self, flag):
"""
Make sure initialized is set to torch.tensor for storage in state_dict.
"""
self._initialized = torch.tensor(flag)
def on_simulation_start(self, simulator: Simulator):
"""
Routine to initialize the barostat based on the current state of the simulator. Reads the device to be uses,
as well as the number of molecular replicas present in simulator.system. A flag is set so that the barostat
is not reinitialized upon continuation of the MD.
Main function is the _init_barostat routine, which takes the simulator as input and must be provided for every
new barostat.
Args:
simulator (schnetpack.md.simulator.Simulator): Main simulator class containing information on
the time step, system, etc.
"""
self.time_step = simulator.integrator.time_step
if not self.initialized:
self._init_barostat(simulator)
self.initialized = True
def on_step_begin(self, simulator: Simulator):
"""
First application of the barostat before the first half step of the dynamics. Must be provided for
every new barostat. This is e.g. used to update the NHC thermostat chains on particles and cells in the NHC
barostats
Args:
simulator (schnetpack.md.simulator.Simulator): Main simulator class containing information on
the time step, system, etc.
"""
raise NotImplementedError
def on_step_end(self, simulator: Simulator):
"""
Second application of the barostat before the first half step of the dynamics. Must be provided for
every new barostat. This is e.g. used to update the NHC thermostat chains on particles and cells in the NHC
barostats
Main function is the _apply_barostat routine, which takes the simulator as input and must be provided for
every new barostat.
Args:
simulator (schnetpack.md.simulator.Simulator): Main simulator class containing information on
the time step, system, etc.
"""
raise NotImplementedError
def _init_barostat(self, simulator: Simulator):
"""
Dummy routine for initializing a barpstat based on the current simulator. Should be implemented for every
new barostat. Has access to the information contained in the simulator class, e.g. number of replicas, time
step, masses of the atoms, etc.
Args:
simulator (schnetpack.md.simulator.Simulator): Main simulator class containing information on
the time step, system, etc.
"""
raise NotImplementedError
def propagate_main_step(self, system: System):
"""
Propagate the system under the conditions imposed by the barostat. This has to be adapted to the specific
barostat algorithm and should propagate the positons, cells and momenta (for RPMD). Defaults to classic Verlet.
Args:
system (schnetpack.md.System): System class containing all molecules and their
replicas.
"""
system.positions = (
system.positions + self.time_step * system.momenta / system.masses
)
def propagate_half_step(self, system: System):
"""
Propagate system under barostat conditions during half steps typically used for momenta. Should be adapted for
each barostat (e.g. NHC propagates paricle and barostat momenta due to barostat actions. Defaults to classic
Verlet.
Args:
system (schnetpack.md.System): System class containing all molecules and their
replicas.
"""
system.momenta = system.momenta + 0.5 * system.forces * self.time_step
[docs]class NHCBarostatIsotropic(BarostatHook):
"""
Nose Hoover chain thermostat/barostat for isotropic cell fluctuations. This barostat already contains a built in
thermostat, so no further temperature control is necessary. As suggested in [#nhc_barostat1]_, two separate chains
are used to thermostat particle and cell momenta.
Args:
target_pressure (float): Target pressure in bar.
temperature_bath (float): Temperature of the external heat bath in Kelvin.
time_constant (float): Particle thermostat time constant in fs
time_constant_cell (float): Cell thermostat time constant in fs. If None is given (default), the same time
constant as for the thermostat component is used.
time_constant_barostat (float): Barostat time constant in fs. If None is given (default), the same time constant
as for the thermostat component is used.
chain_length (int): Number of Nose-Hoover thermostats applied in the chain.
multi_step (int): Number of steps used for integrating the NH equations of motion (default=2)
integration_order (int): Order of the Yoshida-Suzuki integrator used for propagating the thermostat (default=3).
massive (bool): Apply individual thermostat chains to all particle degrees of freedom (default=False).
References
----------
.. [#nhc_barostat1] Martyna, Tuckerman, Tobias, Klein:
Explicit reversible integrators for extended systems dynamics.
Molecular Physics, 87(5), 1117-1157. 1996.
"""
temperature_control = True
ring_polymer = False
def __init__(
self,
target_pressure: float,
temperature_bath: float,
time_constant: float,
time_constant_cell: Optional[float] = None,
time_constant_barostat: Optional[float] = None,
chain_length: Optional[int] = 4,
multi_step: Optional[int] = 4,
integration_order: Optional[int] = 7,
massive: Optional[bool] = False,
):
super(NHCBarostatIsotropic, self).__init__(
target_pressure=target_pressure,
temperature_bath=temperature_bath,
time_constant=time_constant,
)
# Thermostat, cell thermostat and barostat frequencies
self.register_buffer("frequency", 1.0 / self.time_constant)
if time_constant_cell is None:
cell_frequency = self.frequency
else:
cell_frequency = torch.tensor(1.0 / (time_constant_cell * spk_units.fs))
if time_constant_barostat is None:
barostat_frequency = self.frequency
else:
barostat_frequency = torch.tensor(
1.0 / (time_constant_barostat * spk_units.fs)
)
self.register_buffer("cell_frequency", cell_frequency)
self.register_buffer("barostat_frequency", barostat_frequency)
# Cpmpute kBT, since it will be used a lot
self.register_buffer("kb_temperature", self.temperature_bath * spk_units.kB)
# Propagation parameters
self.register_buffer("chain_length", torch.tensor(chain_length))
self.register_buffer("massive", torch.tensor(massive))
self.register_buffer("multi_step", torch.tensor(multi_step))
self.register_buffer("integration_order", torch.tensor(integration_order))
self.register_uninitialized_buffer("ys_time_step")
# Find out number of particles (depends on whether massive or not)
self.register_uninitialized_buffer("degrees_of_freedom")
self.register_uninitialized_buffer("degrees_of_freedom_cell")
self.register_uninitialized_buffer("degrees_of_freedom_particles")
# Thermostat variables for particles
self.register_uninitialized_buffer("t_velocities")
self.register_uninitialized_buffer("t_positions")
self.register_uninitialized_buffer("t_forces")
self.register_uninitialized_buffer("t_masses")
# Thermostat variables for cell
self.register_uninitialized_buffer("t_velocities_cell")
self.register_uninitialized_buffer("t_positions_cell")
self.register_uninitialized_buffer("t_forces_cell")
self.register_uninitialized_buffer("t_masses_cell")
# Barostat variables
self.register_uninitialized_buffer("b_velocities_cell")
self.register_uninitialized_buffer("b_positions_cell")
self.register_uninitialized_buffer("b_forces_cell")
self.register_uninitialized_buffer("b_masses_cell")
# Stable sinh(x)/x approximation
self.sinhdx = StableSinhDiv()
def _init_barostat(self, simulator: Simulator):
"""
Initialize the thermostat positions, forces, velocities and masses, as well as the number of degrees of freedom
seen by each chain link. In the same manner, all quantities required for the barostat are initialized.
Args:
simulator (schnetpack.md.simulator.Simulator): Main simulator class containing information on
the time step, system, etc.
"""
# Determine integration step via multi step and Yoshida Suzuki weights
integration_weights = (
YSWeights()
.get_weights(self.integration_order.item())
.to(simulator.device, simulator.dtype)
)
self.ys_time_step = (
simulator.integrator.time_step * integration_weights / self.multi_step
)
# Determine internal degrees of freedom for barostat
self.degrees_of_freedom = (
3.0 * simulator.system.n_atoms.to(simulator.dtype)[None, :, None]
)
# Determine degrees of freedom for particle thermostats (depends on massive)
n_replicas = simulator.system.n_replicas
n_molecules = simulator.system.n_molecules
n_atoms_total = simulator.system.total_n_atoms
if self.massive:
state_dimension = (n_replicas, n_atoms_total, 3, self.chain_length)
self.degrees_of_freedom_particles = torch.ones(
(n_replicas, n_atoms_total, 3),
device=simulator.device,
dtype=simulator.dtype,
)
else:
state_dimension = (n_replicas, n_molecules, 1, self.chain_length)
self.degrees_of_freedom_particles = self.degrees_of_freedom
# Set up internal variables
self._init_barostat_variables(simulator)
self._init_thermostat_variables(state_dimension, simulator)
def _init_barostat_variables(self, simulator: Simulator):
"""
Initialize all quantities required for the barostat component.
"""
# Set barostat masses
self.b_masses_cell = torch.ones(
(simulator.system.n_replicas, simulator.system.n_molecules, 1),
device=simulator.device,
dtype=simulator.dtype,
)
self.b_masses_cell = (
(self.degrees_of_freedom + 3)
* self.kb_temperature
/ self.barostat_frequency**2
)
# Remaining barostat variables
self.b_velocities_cell = torch.zeros_like(self.b_masses_cell)
self.b_forces_cell = torch.zeros_like(self.b_masses_cell)
# Set cell degrees of freedom (1 for isotropic, 9 for full, 9-3 for full with symmetric pressure (no rotations)
self.degrees_of_freedom_cell = torch.tensor(
1, dtype=simulator.dtype, device=simulator.device
)
def _init_thermostat_variables(
self, state_dimension: Tuple[int, int, int, int], simulator: Simulator
):
"""
Initialize all quantities required for the two thermostat chains on the particles and cells.
Args:
state_dimension (tuple): Size of the thermostat states. This is used to differentiate between the massive
and the standard algorithm
simulator (schnetpack.simulator.Simulator): Main simulator class containing information on the time step,
system, etc.
"""
# Set up thermostat masses
self.t_masses = torch.zeros(
state_dimension, device=simulator.device, dtype=simulator.dtype
)
self.t_masses_cell = torch.zeros(
(
simulator.system.n_replicas,
simulator.system.n_molecules,
1,
self.chain_length,
),
device=simulator.device,
dtype=simulator.dtype,
)
# Get masses of innermost thermostat
self.t_masses[..., 0] = (
self.degrees_of_freedom_particles * self.kb_temperature / self.frequency**2
)
# Get masses of cell
self.t_masses_cell[..., 0] = (
self.degrees_of_freedom_cell * self.kb_temperature / self.cell_frequency**2
)
# Set masses of remaining thermostats
self.t_masses[..., 1:] = self.kb_temperature / self.frequency**2
self.t_masses_cell[..., 1:] = self.kb_temperature / self.cell_frequency**2
# Thermostat variables for particles
self.t_velocities = torch.zeros_like(self.t_masses)
self.t_forces = torch.zeros_like(self.t_masses)
# Thermostat variables for cell
self.t_velocities_cell = torch.zeros_like(self.t_masses_cell)
self.t_forces_cell = torch.zeros_like(self.t_masses_cell)
# Positions for conservation
# self.t_positions = torch.zeros_like(self.t_masses, device=self.device)
# self.t_positions_cell = torch.zeros_like(self.t_masses_cell, device=self.device)
def on_step_begin(self, simulator: Simulator):
"""
Propagate the thermostat chains on particles and cell and update cell velocities using the barostat forces.
Args:
simulator (schnetpack.simulator.Simulator): Main simulator class containing information on the time step,
system, etc.
"""
self._update_thermostat(simulator)
self._update_barostat(simulator)
def on_step_end(self, simulator: Simulator):
"""
Propagate the thermostat chains on particles and cell and update cell velocities using the barostat forces.
Order is reversed in order to preserve symmetric splitting of the overall propagator.
Args:
simulator (schnetpack.simulator.Simulator): Main simulator class containing information on the time step,
system, etc.
"""
self._update_barostat(simulator)
self._update_thermostat(simulator)
def _update_thermostat(self, simulator: Simulator):
"""
Apply the thermostat chains to the system and cell momenta. This is done by computing forces of the innermost
thermostats, propagating the chain forward, updating the box velocities, particle momenta and associated
energies. Based on this, the kinetic energies and forces can be updated, which are the propagated backward
through the chain.
Args:
simulator (schnetpack.md.simulator.Simulator): Main simulator class containing information on
the time step, system, etc.
"""
# Get kinetic energies for cell and particles
kinetic_energy_particles = self._compute_kinetic_energy_particles(
simulator.system
)
kinetic_energy_cell = self._compute_kinetic_energy_cell()
# Update the innermost thermostat forces using the current kinetic energies
self._update_inner_t_forces(kinetic_energy_particles)
self._update_inner_t_forces_cell(kinetic_energy_cell)
# Initialize scaling factor which will be accumulated
scaling_thermostat_particles = torch.ones_like(self.t_velocities[..., 0])
scaling_thermostat_cell = torch.ones_like(self.t_velocities_cell[..., 0])
# Multistep and YS procedure for propagating the thermostat operators
for _ in range(self.multi_step):
for idx_ys in range(self.integration_order):
# Determine integration time step
time_step = self.ys_time_step[idx_ys]
# Propagate the chain inward (forces are not updated, the innermost forces are already initialized)
self._chain_forward(time_step)
# Accumulate scaling
scaling_thermostat_particles *= torch.exp(
-0.5 * time_step * self.t_velocities[..., 0]
)
scaling_thermostat_cell *= torch.exp(
-0.5 * time_step * self.t_velocities_cell[..., 0]
)
# Recompute forces using scaling to update the kinetic energies
self._update_inner_t_forces(
kinetic_energy_particles * scaling_thermostat_particles**2
)
self._update_inner_t_forces_cell(
kinetic_energy_cell * scaling_thermostat_cell**2
)
# Update velocities and forces of remaining thermostats
self._chain_backward(time_step)
# Update the momenta of the particles
if not self.massive:
scaling_thermostat_particles = simulator.system.expand_atoms(
scaling_thermostat_particles
)
simulator.system.momenta = (
simulator.system.momenta * scaling_thermostat_particles
)
# Update cell momenta
self._scale_cell(scaling_thermostat_cell)
def _scale_cell(self, scaling_thermostat_cell: torch.tensor):
"""
Auxiliary routine for scaling the cell, since the scaling factor will be missing one dimension in the
anisotropic barostat
Args:
scaling_thermostat_cell (torch.tensor): Accumulated scaling for cell velocities
"""
self.b_velocities_cell = self.b_velocities_cell * scaling_thermostat_cell
def _compute_kinetic_energy_particles(self, system: System):
"""
Compute the current kinetic energy for the particles, depending on whether massive or normal thermostatting is
used.
Args:
system (schnetpack.md.System): System class containing all molecules and their
replicas.
Returns:
torch.tensor: Current kinetic energy of the particles.
"""
if self.massive:
kinetic_energy_particles = system.momenta**2 / system.masses
else:
kinetic_energy_particles = 2.0 * system.kinetic_energy
return kinetic_energy_particles
def _compute_kinetic_energy_cell(self):
"""
Compute the kinetic energy of the cells.
Returns:
torch.tensor: Kinetic energy associated with the cells.
"""
return self.b_masses_cell * self.b_velocities_cell**2
def _update_inner_t_forces(self, kinetic_energy_particles: torch.tensor):
"""
Update the forces acting on the innermost chain of the particle thermostat.
Args:
kinetic_energy_particles (torch.tensor): kinetic energy of the particles
"""
self.t_forces[..., 0] = (
kinetic_energy_particles
- self.degrees_of_freedom_particles * self.kb_temperature
) / self.t_masses[..., 0]
def _update_inner_t_forces_cell(self, kinetic_energy_cell: torch.tensor):
"""
Update the forces acting on the innermost chain of the cell thermostat.
Args:
kinetic_energy_cell (torch.tensor): kinetic energy of the cell
"""
self.t_forces_cell[..., 0] = (
kinetic_energy_cell - self.degrees_of_freedom_cell * self.kb_temperature
) / self.t_masses_cell[..., 0]
def _chain_forward(self, time_step: float):
"""
Forward propagation of the two Nose-Hoover chains attached to particles and cells.
Force updates are not required here, as long as the innermost force is precomputed, since forces are effectively
taken from the previous step and everything gets gets overwritten by the force update in the backward chain.
Args:
time_step (float): Current timestep considering YS and multi-timestep integration.
"""
# Update velocities of outermost bath
self.t_velocities[..., -1] += 0.25 * self.t_forces[..., -1] * time_step
self.t_velocities_cell[..., -1] += (
0.25 * self.t_forces_cell[..., -1] * time_step
)
# Update the velocities moving through the beads of the chain
for chain in range(self.chain_length - 2, -1, -1):
t_coeff = torch.exp(-0.125 * time_step * self.t_velocities[..., chain + 1])
b_coeff = torch.exp(
-0.125 * time_step * self.t_velocities_cell[..., chain + 1]
)
self.t_velocities[..., chain] = (
self.t_velocities[..., chain] * t_coeff**2
+ 0.25 * self.t_forces[..., chain] * t_coeff * time_step
)
self.t_velocities_cell[..., chain] = (
self.t_velocities_cell[..., chain] * b_coeff**2
+ 0.25 * self.t_forces_cell[..., chain] * b_coeff * time_step
)
def _chain_backward(self, time_step: float):
"""
Backward propagation of the two Nose-Hoover chains attached to particles and cells. In addition, the repsective
thermostat forces are updated.
Args:
time_step (float): Current timestep considering YS and multi-timestep integration.
"""
# Update the thermostat velocities
for chain in range(self.chain_length - 1):
t_coeff = torch.exp(-0.125 * time_step * self.t_velocities[..., chain + 1])
b_coeff = torch.exp(
-0.125 * time_step * self.t_velocities_cell[..., chain + 1]
)
self.t_velocities[..., chain] = (
self.t_velocities[..., chain] * t_coeff**2
+ 0.25 * self.t_forces[..., chain] * t_coeff * time_step
)
self.t_velocities_cell[..., chain] = (
self.t_velocities_cell[..., chain] * b_coeff**2
+ 0.25 * self.t_forces_cell[..., chain] * b_coeff * time_step
)
# Update forces through chain
self.t_forces[..., chain + 1] = (
self.t_masses[..., chain] * self.t_velocities[..., chain] ** 2
- self.kb_temperature
) / self.t_masses[..., chain + 1]
self.t_forces_cell[..., chain + 1] = (
self.t_masses_cell[..., chain] * self.t_velocities_cell[..., chain] ** 2
- self.kb_temperature
) / self.t_masses_cell[..., chain + 1]
# Update velocities of outermost thermostat
self.t_velocities[..., -1] += 0.25 * self.t_forces[..., -1] * time_step
self.t_velocities_cell[..., -1] += (
0.25 * self.t_forces_cell[..., -1] * time_step
)
def _update_barostat(self, simulator: Simulator):
# Get new barostat forces
self._update_b_forces(simulator.system)
# Update the cell velocities
self.b_velocities_cell = (
self.b_velocities_cell + 0.5 * self.time_step * self.b_forces_cell
)
def _update_b_forces(self, system: System):
"""
Update the barostat forces using kinetic energy, current pressure and volume of the system.
Args:
system (schnetpack.md.System): System class containing all molecules and their
replicas.
"""
# Get the pressure (R x M x 1)
pressure = system.compute_pressure(kinetic_component=False, tensor=False)
# Get the volume (R x M x 1)
volume = system.volume
# Get the kinetic energy
kinetic_energy = 2.0 * system.kinetic_energy
self.b_forces_cell = (
(1.0 + 3.0 / self.degrees_of_freedom) * kinetic_energy
+ 3.0 * volume * (pressure - self.target_pressure)
) / self.b_masses_cell
def propagate_main_step(self, system: System):
"""
Main routine for propagating the system positions and cells. Since this is modified, no conventional velocity
verlet integrator can be used.
Args:
system (schnetpack.md.System): System class containing all molecules and their
replicas.
"""
scaled_velocity = self.time_step * self.b_velocities_cell
# Compute exponential coefficient
a_coeff = torch.exp(0.5 * scaled_velocity)
# Compute sinh(x)/x term
b_coeff = a_coeff * self.sinhdx.f(0.5 * scaled_velocity)
# Update the particle positions
system.positions = (
system.positions * system.expand_atoms(a_coeff**2)
+ system.momenta
/ system.masses
* system.expand_atoms(b_coeff)
* self.time_step
)
# Scale the cells (propagation is in logarithmic space)
cell_coeff = torch.exp(scaled_velocity)[..., None]
system.cells = system.cells * cell_coeff
def propagate_half_step(self, system: System):
"""
Main routine for propagating the system momenta. Since this is modified, no conventional velocity
verlet integrator can be used.
Args:
system (schnetpack.md.System): System class containing all molecules and their
replicas.
"""
# Compute basic argument
scaled_velocity = (
0.25
* self.time_step
* self.b_velocities_cell
* (1.0 + 3.0 / self.degrees_of_freedom)
)
# Compute exponential coefficient
a_coeff = torch.exp(-scaled_velocity)
# Compute sinh(x)/x term
b_coeff = a_coeff * self.sinhdx.f(scaled_velocity)
# Update the momenta (using half timestep)
system.momenta = (
system.momenta * system.expand_atoms(a_coeff**2)
+ system.forces * system.expand_atoms(b_coeff) * self.time_step * 0.5
)
# def compute_conserved(self, system):
# """
# Computed the conserved quantity. For debug purposes only.
# """
# conserved = (
# system.kinetic_energy[..., None, None]
# + system.energies[..., None, None]
# + 0.5 * torch.sum(self.t_velocities ** 2 * self.t_masses, 2)
# + 0.5 * torch.sum(self.t_velocities_cell ** 2 * self.t_masses_cell, 2)
# + 0.5 * self.b_velocities_cell ** 2 * self.b_masses_cell
# + self.degrees_of_freedom * self.kb_temperature * self.t_positions[..., 0]
# + self.kb_temperature * self.t_positions_cell[..., 0]
# + self.kb_temperature * torch.sum(self.t_positions[..., 1:], 2)
# + self.kb_temperature * torch.sum(self.t_positions_cell[..., 1:], 2)
# + self.target_pressure * system.volume
# )
# return conserved
[docs]class NHCBarostatAnisotropic(NHCBarostatIsotropic):
"""
Nose Hoover chain thermostat/barostat for anisotropic cell fluctuations. This barostat already contains a built in
thermostat, so no further temperature control is necessary. As suggested in [#nhc_barostat1]_, two separate chains
are used to thermostat particle and cell momenta.
Args:
target_pressure (float): Target pressure in bar.
temperature_bath (float): Temperature of the external heat bath in Kelvin.
time_constant (float): Particle thermostat time constant in fs
time_constant_cell (float): Cell thermostat time constant in fs. If None is given (default), the same time
constant as for the thermostat component is used.
time_constant_barostat (float): Barostat time constant in fs. If None is given (default), the same time constant
as for the thermostat component is used.
chain_length (int): Number of Nose-Hoover thermostats applied in the chain.
multi_step (int): Number of steps used for integrating the NH equations of motion (default=2)
integration_order (int): Order of the Yoshida-Suzuki integrator used for propagating the thermostat (default=3).
massive (bool): Apply individual thermostat chains to all particle degrees of freedom (default=False).
References
----------
.. [#nhc_barostat1] Martyna, Tuckerman, Tobias, Klein:
Explicit reversible integrators for extended systems dynamics.
Molecular Physics, 87(5), 1117-1157. 1996.
"""
temperature_control = True
ring_polymer = False
def __init__(
self,
target_pressure: float,
temperature_bath: float,
time_constant: float,
time_constant_cell: Optional[float] = None,
time_constant_barostat: Optional[float] = None,
chain_length: Optional[int] = 4,
multi_step: Optional[int] = 4,
integration_order: Optional[int] = 7,
massive: Optional[bool] = False,
):
super(NHCBarostatAnisotropic, self).__init__(
target_pressure=target_pressure,
temperature_bath=temperature_bath,
time_constant=time_constant,
time_constant_cell=time_constant_cell,
time_constant_barostat=time_constant_barostat,
chain_length=chain_length,
multi_step=multi_step,
integration_order=integration_order,
massive=massive,
)
def _init_barostat_variables(self, simulator: Simulator):
"""
Initialize all quantities required for the barostat component.
"""
# Set barostat masses
self.b_masses_cell = torch.ones(
(simulator.system.n_replicas, simulator.system.n_molecules, 1),
device=simulator.device,
dtype=simulator.dtype,
)
# Modified due to full cell
self.b_masses_cell = (
(self.degrees_of_freedom + 3)
* self.kb_temperature
/ self.barostat_frequency**2
/ 3.0
)
# Remaining barostat variables (forces and velocities are now 3 x 3)
self.b_velocities_cell = torch.zeros(
(simulator.system.n_replicas, simulator.system.n_molecules, 3, 3),
device=simulator.device,
dtype=simulator.dtype,
)
self.b_forces_cell = torch.zeros_like(self.b_velocities_cell)
# Auxiliary identity matrix for broadcasting
self.register_buffer(
"aux_eye",
torch.eye(3, device=simulator.device, dtype=simulator.dtype)[
None, None, :, :
],
)
# Set cell degrees of freedom (1 for isotropic, 9 for full, 9-3 for full with symmetric pressure (no rotations)
self.degrees_of_freedom_cell = torch.tensor(
6, dtype=simulator.dtype, device=simulator.device
)
def _scale_cell(self, scaling_thermostat_cell: torch.tensor):
"""
Auxiliary routine for scaling the cell, here the scaling factor needs one additional dimension compared to the
isotropic case.
Args:
scaling_thermostat_cell (torch.tensor): Accumulated scaling for cell velocities
"""
self.b_velocities_cell = (
self.b_velocities_cell * scaling_thermostat_cell[..., None]
)
def _compute_kinetic_energy_cell(self):
"""
Compute the kinetic energy of the cells.
Returns:
torch.tensor: Kinetic energy associated with the cells.
"""
b_velocities_cell_sq = torch.sum(
self.b_velocities_cell**2, dim=(2, 3), keepdim=True
).squeeze(-1)
return self.b_masses_cell * b_velocities_cell_sq
def _update_b_forces(self, system: System):
"""
Update the barostat forces using kinetic energy, current pressure and volume of the system.
Args:
system (schnetpack.md.System): System class containing all molecules and their
replicas.
"""
# Get the pressure (R x M x 1)
pressure = system.compute_pressure(kinetic_component=True, tensor=True)
# Get the volume (R x M x 1)
volume = system.volume
# Get the kinetic energy
kinetic_energy = 2.0 * system.kinetic_energy
self.b_forces_cell = (
volume[..., None] * (pressure - self.aux_eye * self.target_pressure)
+ kinetic_energy[..., None]
/ self.degrees_of_freedom[..., None]
* self.aux_eye
) / self.b_masses_cell[..., None]
def propagate_main_step(self, system: System):
"""
Main routine for propagating the system positions and cells. Since this is modified, no conventional velocity
verlet integrator can be used.
Args:
system (schnetpack.md.System): System class containing all molecules and their
replicas.
"""
eigval_b_velocities, eigvec_b_velocities = torch.symeig(
self.b_velocities_cell, eigenvectors=True
)
scaled_velocity = eigval_b_velocities * self.time_step
coeff_a = torch.exp(0.5 * scaled_velocity)[..., None] * self.aux_eye
coeff_b = coeff_a * self.sinhdx.f(0.5 * scaled_velocity)[..., None]
# Construct matrix operators and update positions using positions and momenta
operator_a = torch.matmul(
eigvec_b_velocities,
torch.matmul(coeff_a**2, eigvec_b_velocities.transpose(2, 3)),
)
operator_b = torch.matmul(
eigvec_b_velocities,
torch.matmul(coeff_b, eigvec_b_velocities.transpose(2, 3)),
)
update_positions = torch.sum(
system.expand_atoms(operator_a) * system.positions[..., None], dim=2
)
update_momenta = torch.sum(
system.expand_atoms(operator_b)
* (system.momenta / system.masses)[..., None],
dim=2,
)
system.positions = update_positions + update_momenta * self.time_step
# Update cells using first operator
system.cells = torch.matmul(system.cells, operator_a)
def propagate_half_step(self, system: System):
"""
Main routine for propagating the system momenta. Since this is modified, no conventional velocity
verlet integrator can be used.
Args:
system (schnetpack.md.System): System class containing all molecules and their
replicas.
"""
eigval_b_velocities, eigvec_b_velocities = torch.symeig(
self.b_velocities_cell, eigenvectors=True
)
# Trace of matrix is sum of eigenvalues
trace_b_velocities = torch.sum(eigval_b_velocities, dim=2, keepdim=True)
scaled_velocity = (
(eigval_b_velocities + trace_b_velocities / self.degrees_of_freedom)
* self.time_step
* 0.5
)
coeff_a = torch.exp(-0.5 * scaled_velocity)[..., None] * self.aux_eye
coeff_b = coeff_a * self.sinhdx.f(0.5 * scaled_velocity)[..., None]
# Construct matrix operators and update positions using positions and momenta
operator_a = torch.matmul(
eigvec_b_velocities,
torch.matmul(coeff_a**2, eigvec_b_velocities.transpose(2, 3)),
)
operator_b = torch.matmul(
eigvec_b_velocities,
torch.matmul(coeff_b, eigvec_b_velocities.transpose(2, 3)),
)
update_momenta = torch.sum(
system.expand_atoms(operator_a) * system.momenta[..., None], dim=2
)
update_forces = torch.sum(
system.expand_atoms(operator_b) * system.forces[..., None], dim=2
)
system.momenta = update_momenta + 0.5 * self.time_step * update_forces