"""
All molecular dynamics in SchNetPack is performed using the :obj:`schnetpack.md.Simulator` class.
This class collects the atomistic system (:obj:`schnetpack.md.System`), calculators (:obj:`schnetpack.md.calculators`),
integrators (:obj:`schnetpack.md.integrators`) and various simulation hooks (:obj:`schnetpack.md.simulation_hooks`)
and performs the time integration.
"""
from contextlib import nullcontext
from typing import List, Optional

import torch
import torch.nn as nn
from tqdm import trange

from schnetpack.md import System
# Public names exported by ``from <this module> import *``.
__all__ = ["Simulator"]
class Simulator(nn.Module):
    """
    Main driver of the molecular dynamics simulation. Uses an integrator to
    propagate the molecular system defined in the system class according to
    the forces yielded by a provided calculator.

    In addition, hooks can be applied at six different stages of the
    simulation:

    - ``on_simulation_start``: start of the simulation (e.g. for initializing
      thermostats)
    - ``on_step_begin``: before the first integrator half step (e.g.
      thermostats)
    - ``on_step_middle``: after computation of the forces and before the
      second integrator half step (e.g. for accelerated MD)
    - ``on_step_end``: after the second integrator half step (e.g.
      thermostats); hooks are called in reverse order here
    - ``on_step_finalize``: at the very end of each step (e.g. output routines)
    - ``on_simulation_end``: at the end of the simulation (e.g. general wrap up
      of file writes, etc.)

    This routine has a state dict which can be used to restart a previous
    simulation.

    Args:
        system (schnetpack.md.System): Instance of the system class holding the
            structures, masses, atom types, momenta, forces and properties of
            all molecules and their replicas.
        integrator (schnetpack.md.Integrator): Integrator for propagating the
            molecular dynamics simulation, defined in schnetpack.md.integrators.
        calculator (schnetpack.md.calculator): Calculator used to compute
            molecular forces for propagation and (if requested) various other
            properties.
        simulator_hooks (list(object), optional): Hooks to be applied during
            the simulation, e.g. file loggers and thermostats. Defaults to no
            hooks.
        step (int): Index of the initial simulation step.
        restart (bool): Indicates whether the simulation is restarted. E.g. if
            set to True, the simulator tries to continue logging in the
            previously created dataset (default=False). This is set
            automatically by restart_simulation. Enabling it without that
            function currently only makes sense if independent simulations
            should be written to the same file.
        gradients_required (bool): If True, gradient tracking is explicitly
            enabled for the whole simulation (even when called from inside a
            ``torch.no_grad()`` block). If False (default), the caller's current
            autograd mode is left unchanged.
        progress (bool): Show a progress bar during the simulation. Can be
            deactivated e.g. for cluster runs.
    """

    def __init__(
        self,
        system: "System",
        integrator,
        calculator,
        simulator_hooks: Optional[List] = None,
        step: int = 0,
        restart: bool = False,
        gradients_required: bool = False,
        progress: bool = True,
    ):
        super(Simulator, self).__init__()
        self.system = system
        self.integrator = integrator
        self.calculator = calculator
        # Use a fresh list instead of a shared mutable default argument.
        if simulator_hooks is None:
            simulator_hooks = []
        self.simulator_hooks = torch.nn.ModuleList(simulator_hooks)
        self.step = step
        self.n_steps = None
        self.restart = restart
        self.gradients_required = gradients_required
        self.progress = progress

        # Keep track of the actual simulation steps performed with simulate calls
        self.effective_steps = 0

    @property
    def device(self):
        """Device of the simulated system (delegates to ``self.system``)."""
        return self.system.device

    @property
    def dtype(self):
        """Data type of the simulated system (delegates to ``self.system``)."""
        return self.system.dtype

    def simulate(self, n_steps: int):
        """
        Main simulation function. Propagates the system for a certain number
        of steps.

        Each step performs: ``on_step_begin`` hooks, momentum half step,
        main (position) step, force computation, ``on_step_middle`` hooks,
        second momentum half step, ``on_step_end`` hooks (reverse order) and
        ``on_step_finalize`` hooks. Forces are computed once before the loop so
        the first half step has valid forces.

        Args:
            n_steps (int): Number of simulation steps to be performed.
        """
        self.n_steps = n_steps

        # Progress bar only if requested (can be disabled e.g. for cluster runs)
        iterator = trange if self.progress else range

        # If gradients are required, make sure the computational graph is built
        # even when the caller disabled autograd. Otherwise leave the caller's
        # grad mode untouched (calculators may rely on autograd for forces).
        if self.gradients_required:
            grad_context = torch.enable_grad()
        else:
            grad_context = nullcontext()

        with grad_context:
            # Perform initial computation of forces
            self.calculator.calculate(self.system)

            # Call hooks at the simulation start
            for hook in self.simulator_hooks:
                hook.on_simulation_start(self)

            for _ in iterator(n_steps):
                # Call hook before first half step
                for hook in self.simulator_hooks:
                    hook.on_step_begin(self)

                # Do half step momenta
                self.integrator.half_step(self.system)

                # Do propagation MD/PIMD
                self.integrator.main_step(self.system)

                # Compute new forces
                self.calculator.calculate(self.system)

                # Call hook after forces
                for hook in self.simulator_hooks:
                    hook.on_step_middle(self)

                # Do half step momenta
                self.integrator.half_step(self.system)

                # Hooks are called in reverse order to guarantee symmetry of
                # the propagator when using thermostats and barostats
                for hook in reversed(self.simulator_hooks):
                    hook.on_step_end(self)

                # Logging hooks etc
                for hook in self.simulator_hooks:
                    hook.on_step_finalize(self)

                self.step += 1
                self.effective_steps += 1

            # Call hooks at the simulation end
            for hook in self.simulator_hooks:
                hook.on_simulation_end(self)

    # NOTE: this property shadows nn.Module.state_dict(); access it as an
    # attribute (``simulator.state_dict``), not as a method call.
    @property
    def state_dict(self):
        """
        State dict used to restart the simulation. Generates a dictionary with
        the following entries:

        - step: current simulation step
        - system: state dict of the system holding current positions,
          momenta, forces, etc...
        - simulator_hooks: dict of state dicts of the various hooks used
          during simulation, keyed by the hook class. If several hooks share
          the same class, only the last one's state is kept.

        Returns:
            dict: State dict containing the current step, the system
                parameters (positions, momenta, etc.) and all
                simulator_hook state dicts
        """
        state_dict = {
            "step": self.step,
            "system": self.system.state_dict(),
            "simulator_hooks": {
                hook.__class__: hook.state_dict() for hook in self.simulator_hooks
            },
        }
        return state_dict

    @state_dict.setter
    def state_dict(self, state_dict):
        """
        Set the current state of the simulator from a state dict. Hook states
        are matched by hook class; hooks whose class is missing from the state
        dict are left unchanged. For restarting simulations, see
        restart_simulation.

        Args:
            state_dict (dict): state dict containing the entries 'step',
                'simulator_hooks' and 'system'.
        """
        self.step = state_dict["step"]
        self.system.load_state_dict(state_dict["system"])
        self._load_hook_states(state_dict["simulator_hooks"])

    def _load_hook_states(self, hook_states):
        """Load the state of every hook whose class has an entry in hook_states."""
        for hook in self.simulator_hooks:
            if hook.__class__ in hook_states:
                hook.load_state_dict(hook_states[hook.__class__])

    def restart_simulation(self, state_dict, soft=False):
        """
        Routine for restarting a simulation. Reads the current step, as well
        as the system state from the provided state dict and sets the
        ``restart`` flag.

        In a hard restart (default), only thermostat hooks (hooks with a
        ``temperature_bath`` attribute) are restored, and every such hook in
        the current simulation must have an entry in the state dict.
        In a soft restart, every hook whose class is present in the state dict
        is restored, and missing entries are silently skipped.

        Args:
            state_dict (dict): State dict of the previous simulation.
            soft (bool): Flag to toggle hard/soft restarts (default=False).

        Raises:
            ValueError: In a hard restart, if a thermostat hook has no
                restart information in the state dict.
        """
        # TODO: restart with metadynamics hooks etc, ?
        self.step = state_dict["step"]
        # NOTE(review): the state_dict setter uses system.load_state_dict while
        # this uses system.load_system_state — confirm the difference is intended.
        self.system.load_system_state(state_dict["system"])

        hook_states = state_dict["simulator_hooks"]
        if soft:
            # Do the same as in a basic state dict setting
            self._load_hook_states(hook_states)
        else:
            # Hard restart, require all thermostats to be there
            for hook in self.simulator_hooks:
                # Only thermostat hooks carry a temperature bath
                if not hasattr(hook, "temperature_bath"):
                    continue
                if hook.__class__ not in hook_states:
                    raise ValueError(
                        f"Could not find restart information for {hook.__class__} in state dict."
                    )
                hook.load_state_dict(hook_states[hook.__class__])

        # In this case, set restart flag automatically
        self.restart = True