Source code for md.simulator

"""
All molecular dynamics in SchNetPack is performed using the :obj:`schnetpack.md.Simulator` class.
This class collects the atomistic system (:obj:`schnetpack.md.System`), calculators (:obj:`schnetpack.md.calculators`),
integrators (:obj:`schnetpack.md.integrators`) and various simulation hooks (:obj:`schnetpack.md.simulation_hooks`)
and performs the time integration.
"""

import torch
import torch.nn as nn
from contextlib import nullcontext

from tqdm import trange

from schnetpack.md import System

__all__ = ["Simulator"]


[docs]class Simulator(nn.Module): """ Main driver of the molecular dynamics simulation. Uses an integrator to propagate the molecular system defined in the system class according to the forces yielded by a provided calculator. In addition, hooks can be applied at five different stages of each simulation step: - Start of the simulation (e.g. for initializing thermostat) - Before first integrator half step (e.g. thermostat) - After computation of the forces and before main integrator step (e.g. for accelerated MD) - After second integrator half step (e.g. thermostat, output routines) - At the end of the simulation (e.g. general wrap up of file writes, etc.) This routine has a state dict which can be used to restart a previous simulation. Args: system (schnetpack.md.System): Instance of the system class defined in molecular_dynamics.system holding the structures, masses, atom type, momenta, forces and properties of all molecules and their replicas integrator (schnetpack.md.Integrator): Integrator for propagating the molecular dynamics simulation, defined in schnetpack.md.integrators calculator (schnetpack.md.calculator): Calculator class used to compute molecular forces for propagation and (if requested) various other properties. simulator_hooks (list(object)): List of different hooks to be applied during simulations. Examples would be file loggers and thermostat. step (int): Index of the initial simulation step. restart (bool): Indicates, whether the simulation is restarted. E.g. if set to True, the simulator tries to continue logging in the previously created dataset. (default=False) This is set automatically by the restart_simulation function. Enabling it without the function currently only makes sense if independent simulations should be written to the same file. progress (bool): show progress bar during simulation. Can be deactivated e.g. for cluster runs. """ def __init__( self, system: System, integrator, calculator, simulator_hooks: list = [], step: int = 0, restart: bool = False, gradients_required: bool = False, progress: bool = True, ): super(Simulator, self).__init__() self.system = system self.integrator = integrator self.calculator = calculator self.simulator_hooks = torch.nn.ModuleList(simulator_hooks) self.step = step self.n_steps = None self.restart = restart self.gradients_required = gradients_required self.progress = progress # Keep track of the actual simulation steps performed with simulate calls self.effective_steps = 0 @property def device(self): return self.system.device @property def dtype(self): return self.system.dtype def simulate(self, n_steps: int): """ Main simulation function. Propagates the system for a certain number of steps. Args: n_steps (int): Number of simulation steps to be performed. """ self.n_steps = n_steps # Determine iterator if self.progress: iterator = trange else: iterator = range # Check, if computational graph should be built if self.gradients_required: grad_context = torch.no_grad() else: grad_context = nullcontext() with grad_context: # Perform initial computation of forces self.calculator.calculate(self.system) # Call hooks at the simulation start for hook in self.simulator_hooks: hook.on_simulation_start(self) for _ in iterator(n_steps): # Call hook before first half step for hook in self.simulator_hooks: hook.on_step_begin(self) # Do half step momenta self.integrator.half_step(self.system) # Do propagation MD/PIMD self.integrator.main_step(self.system) # Compute new forces self.calculator.calculate(self.system) # Call hook after forces for hook in self.simulator_hooks: hook.on_step_middle(self) # Do half step momenta self.integrator.half_step(self.system) # Call hooks after second half step # Hooks are called in reverse order to guarantee symmetry of # the propagator when using thermostat and barostats for hook in self.simulator_hooks[::-1]: hook.on_step_end(self) # Logging hooks etc for hook in self.simulator_hooks: hook.on_step_finalize(self) self.step += 1 self.effective_steps += 1 # Call hooks at the simulation end for hook in self.simulator_hooks: hook.on_simulation_end(self) @property def state_dict(self): """ State dict used to restart the simulation. Generates a dictionary with the following entries: - step: current simulation step - systems: state dict of the system holding current positions, momenta, forces, etc... - simulator_hooks: dict of state dicts of the various hooks used during simulation using their basic class name as keys. Returns: dict: State dict containing the current step, the system parameters (positions, momenta, etc.) and all simulator_hook state dicts """ state_dict = { "step": self.step, "system": self.system.state_dict(), "simulator_hooks": { hook.__class__: hook.state_dict() for hook in self.simulator_hooks }, } return state_dict @state_dict.setter def state_dict(self, state_dict): """ Set the current state dict of the simulator using a state dict defined in state_dict. This routine assumes, that the identity of all hooks has not changed and the order is preserved. A more general method to restart simulations is provided below. Args: state_dict (dict): state dict containing the entries 'step', 'simulator_hooks' and 'system'. """ self.step = state_dict["step"] self.system.load_state_dict(state_dict["system"]) # Set state dicts of all hooks for hook in self.simulator_hooks: if hook.__class__ in state_dict["simulator_hooks"]: hook.load_state_dict(state_dict["simulator_hooks"][hook.__class__]) def restart_simulation(self, state_dict, soft=False): """ Routine for restarting a simulation. Reads the current step, as well as system state from the provided state dict. In case of the simulation hooks, only the states of the thermostat hooks are restored, as all other hooks do not depend on previous simulations. If the soft option is chosen, only restores states of thermostat if they are present in the current simulation and the state dict. Otherwise, all thermostat found in the state dict are required to be present in the current simulation. Args: state_dict (dict): State dict of the current simulation soft (bool): Flag to toggle hard/soft thermostat restarts ( default=False) """ # TODO: restart with metadynamics hooks etc, ? self.step = state_dict["step"] self.system.load_system_state(state_dict["system"]) if soft: # Do the same as in a basic state dict setting for hook in self.simulator_hooks: if hook.__class__ in state_dict["simulator_hooks"]: hook.load_state_dict(state_dict["simulator_hooks"][hook.__class__]) else: # Hard restart, require all thermostat to be there for hook in self.simulator_hooks: # Check if hook is thermostat if hasattr(hook, "temperature_bath"): if hook.__class__ not in state_dict["simulator_hooks"]: raise ValueError( f"Could not find restart information for {hook.__class__} in state dict." ) else: hook.load_state_dict( state_dict["simulator_hooks"][hook.__class__] ) # In this case, set restart flag automatically self.restart = True