Source code for md.simulation_hooks.callback_hooks

"""
This module contains different hooks for monitoring the simulation and checkpointing.
"""

from __future__ import annotations
from typing import Union, List, Dict, Tuple, Any
from typing import TYPE_CHECKING

import schnetpack.units

if TYPE_CHECKING:
    from schnetpack.md import System
    from schnetpack.md import Simulator

import torch
import json
import os
import h5py
import numpy as np

from schnetpack.md.simulation_hooks import SimulationHook


class Checkpoint(SimulationHook):
    """
    Hook that writes the simulator ``state_dict`` to a checkpoint file.

    The stored state can be used to restart the simulation from a previous step
    or a previous system configuration. Every write goes to the same file.

    Args:
        checkpoint_file (str): Name of the file the state_dict is stored in.
        every_n_steps (int): Number of simulation steps between two checkpoints.
    """

    def __init__(self, checkpoint_file: str, every_n_steps: int):
        super().__init__()
        self.checkpoint_file = checkpoint_file
        self.every_n_steps = every_n_steps

    def on_step_finalize(self, simulator: Simulator):
        """
        Write a checkpoint whenever the current step is a multiple of
        ``every_n_steps``.

        Args:
            simulator (schnetpack.md.Simulator): Simulator class used in the
                molecular dynamics simulation.
        """
        is_checkpoint_step = simulator.step % self.every_n_steps == 0
        if not is_checkpoint_step:
            return
        torch.save(simulator.state_dict, self.checkpoint_file)

    def on_simulation_end(self, simulator: Simulator):
        """
        Write a final checkpoint once the simulation has finished.

        Args:
            simulator (schnetpack.md.Simulator): Simulator class used in the
                molecular dynamics simulation.
        """
        torch.save(simulator.state_dict, self.checkpoint_file)
class DataStream:
    """
    Base class for the data streams used by the FileLogger.

    A stream owns one dataset of the main hdf5 file, accumulates entries in a
    buffer and writes the buffer to the file when it is flushed. Subclasses
    have to implement `_init_data_stream` and `update_buffer`.

    Args:
        group_name (str): Name of the data group entry.
    """

    def __init__(self, group_name: str):
        self.group_name = group_name

        # All remaining attributes are populated by `init_data_stream`.
        self.main_dataset = None
        self.data_group = None
        self.buffer = None
        self.buffer_size = None
        self.every_n_steps = None
        self.restart = None
        self.precision = None

    @staticmethod
    def _precision(precision: int):
        """
        Look up the numpy float type with the requested number of bits.

        Args:
            precision (int): Bit width of the float type (e.g. 16, 32, 64).

        Returns:
            The attribute ``float<precision>`` of numpy.

        Raises:
            AttributeError: If numpy has no float type of that width.
        """
        float_type = getattr(np, f"float{precision}", None)
        if float_type is None:
            raise AttributeError(f"Unknown float precision {precision}")
        return float_type

    def init_data_stream(
        self,
        simulator: Simulator,
        main_dataset,
        buffer_size: int,
        restart: bool = False,
        every_n_steps: int = 1,
        precision: int = 32,
    ):
        """
        Store the logging settings and run the stream specific initialization.

        Unless the stream is restarted, the 'entries' attribute of the data
        group (number of valid entries written so far) is reset to zero. This
        counter is needed because the dataset is allocated for the full length
        of the simulation, so readers cannot infer the valid range from the
        array shape alone.

        Args:
            simulator (schnetpack.simulation_hooks.Simulator): Simulator class
                used in the molecular dynamics simulation.
            main_dataset (h5py.File): Main h5py dataset object.
            buffer_size (int): Size of the buffer, once full, data is stored to
                the hdf5 dataset.
            restart (bool): If the simulation is restarted, continue logging in
                the previously created dataset. (default=False)
            every_n_steps (int): How often simulation steps are logged. Used
                e.g. to determine overall time step in MoleculeStream.
            precision (int): Precision used for storing data.
        """
        self.main_dataset = main_dataset
        self.buffer_size = buffer_size
        self.restart = restart
        self.every_n_steps = every_n_steps
        self.precision = self._precision(precision)

        # Subclass hook; expected to create self.data_group and self.buffer.
        self._init_data_stream(simulator)

        # A fresh stream starts without any meaningful entries.
        if not self.restart:
            self.data_group.attrs["entries"] = 0

    def _init_data_stream(self, simulator: Simulator):
        """
        Stream specific initialization routine. Has to be overridden.

        Args:
            simulator (schnetpack.simulation_hooks.Simulator): Simulator class
                used in the molecular dynamics simulation.
        """
        raise NotImplementedError

    def update_buffer(self, buffer_position: int, simulator: Simulator):
        """
        Copy the current simulation data into the buffer. Has to be overridden.

        Args:
            buffer_position (int): Current position in the buffer.
            simulator (schnetpack.simulation_hooks.Simulator): Simulator class
                used in the molecular dynamics simulation.
        """
        raise NotImplementedError

    def flush_buffer(self, file_position: int, buffer_position: int):
        """
        Write the filled part of the buffer into the main hdf5 file.

        Args:
            file_position (int): Current position in the main dataset file.
            buffer_position (int): Number of filled buffer rows; rows beyond
                this position are not written.
        """
        n_valid = file_position + buffer_position
        filled_rows = self.buffer[:buffer_position].detach().cpu()
        self.data_group[file_position:n_valid] = filled_rows

        # Update number of meaningful entries
        self.data_group.attrs.modify("entries", n_valid)
        self.data_group.flush()

    def _setup_data_groups(self, data_shape: Tuple[Any, int], simulator: Simulator):
        """
        Create the logging buffer and create or reopen the hdf5 data group.

        A new data group is allocated for ``simulator.n_steps`` entries. On
        restart the existing group is enlarged to hold ``simulator.n_steps``
        entries in addition to the ones already counted in its 'entries'
        attribute.

        Args:
            data_shape (tuple(int)): Shape of a single entry of the stream.
            simulator (schnetpack.simulation_hooks.Simulator): Simulator class
                used in the molecular dynamics simulation.
        """
        # Initialize the buffer
        self.buffer = torch.zeros(
            self.buffer_size,
            *data_shape,
            device=simulator.system.device,
            dtype=simulator.system.dtype,
        )

        if not self.restart:
            # Generate a new data group; one chunk holds one full buffer.
            self.data_group = self.main_dataset.create_dataset(
                self.group_name,
                shape=(simulator.n_steps,) + data_shape,
                chunks=self.buffer.shape,
                dtype=self.precision,
                maxshape=(None,) + data_shape,
            )
            return

        # Load previous data stream and resize
        self.data_group = self.main_dataset[self.group_name]
        total_steps = simulator.n_steps + self.data_group.attrs["entries"]
        self.data_group.resize((total_steps,) + data_shape)
class MoleculeStream(DataStream):
    """
    DataStream logging the state of the system to the group 'molecules' of the
    main hdf5 dataset.

    Each logged step is stored as one flat vector per replica, i.e. the data
    group has the shape n_steps x n_replicas x n_entries. The entries of a
    vector are laid out in the following order:

      * energies (n_molecules values),
      * positions (3 * total_n_atoms values),
      * velocities (3 * total_n_atoms values, only if `store_velocities`),
      * cells, then stress tensors (9 * n_molecules values each, only if none
        of the system volumes is zero).

    Atom types, masses, periodic boundary conditions, the numbers of replicas,
    molecules and atoms, as well as the time between two logged steps are
    stored in the group attributes.

    Args:
        store_velocities (bool): store atoms velocities in addition to positions
    """

    def __init__(self, store_velocities: bool):
        super().__init__("molecules")
        self.store_velocities = store_velocities
        self.cells = False
        self.written = 0

    def _init_data_stream(self, simulator: Simulator):
        """
        Determine the length of the flat data vector, set up buffer and data
        group and, for a fresh stream, write the system metadata to the group
        attributes.

        Args:
            simulator (schnetpack.md.Simulator): Simulator class used in the
                molecular dynamics simulation.
        """
        system = simulator.system

        # Energies (one per molecule) and positions are always stored
        n_entries = system.n_molecules + system.total_n_atoms * 3

        # If requested, also store velocities
        if self.store_velocities:
            n_entries = n_entries + system.total_n_atoms * 3

        # Cells and stress tensors are only stored if no volume is zero
        if not torch.any(system.volume == 0.0):
            self.cells = True
            n_entries = n_entries + system.n_molecules * 2 * 9

        self._setup_data_groups((system.n_replicas, n_entries), simulator)

        if self.restart:
            # Metadata was already written when the group was created.
            return

        attrs = self.data_group.attrs
        attrs["n_replicas"] = system.n_replicas
        attrs["n_molecules"] = system.n_molecules
        attrs["total_n_atoms"] = system.total_n_atoms
        attrs["n_atoms"] = system.n_atoms.cpu()
        attrs["atom_types"] = system.atom_types.cpu()
        # Squeeze to shape of Z
        attrs["masses"] = system.masses.cpu()[0, :, 0]
        # Remove training broadcast dimension
        attrs["pbc"] = system.pbc.cpu()[0]
        attrs["has_cells"] = self.cells
        attrs["has_velocities"] = self.store_velocities
        # Time between two logged steps
        attrs["time_step"] = simulator.integrator.time_step * self.every_n_steps
        attrs["every_n_steps"] = self.every_n_steps

    def update_buffer(self, buffer_position: int, simulator: Simulator):
        """
        Copy the current system state into row `buffer_position` of the buffer.

        Args:
            buffer_position (int): Current position in the buffer.
            simulator (schnetpack.md.Simulator): Simulator class used in the
                molecular dynamics simulation.
        """
        system = simulator.system
        n_replicas = system.n_replicas

        def _store(tensor, offset, width):
            # Flatten per replica, write into columns [offset, offset + width)
            # and return the offset of the next free column.
            end = offset + width
            self.buffer[buffer_position : buffer_position + 1, :, offset:end] = (
                tensor.view(n_replicas, -1).detach()
            )
            return end

        # Energies and positions are always present
        cursor = _store(system.energy, 0, system.n_molecules)
        cursor = _store(system.positions, cursor, system.total_n_atoms * 3)

        if self.store_velocities:
            cursor = _store(system.velocities, cursor, system.total_n_atoms * 3)

        if self.cells:
            # Cells first, stress tensors afterwards
            cursor = _store(system.cells, cursor, 9 * system.n_molecules)
            _store(system.stress, cursor, 9 * system.n_molecules)
class PropertyStream(DataStream):
    """
    DataStream logging the properties predicted by the calculator to the group
    'properties' of the hdf5 dataset.

    All target properties are flattened per replica and concatenated into one
    vector, so the data group has the shape n_steps x n_replicas x n_entries.
    The original shape of every property and its position within the vector
    are stored as JSON strings in the group attributes. Since the shapes are
    taken from the system.properties dictionary, at least one computation has
    to be performed before the stream is initialized.

    Args:
        target_properties (list): List of properties to be written to the hdf5
            database. If no list is given, defaults to None, which means all
            properties are stored.
    """

    def __init__(self, target_properties: List[str] = None):
        super().__init__("properties")
        self.target_properties = target_properties
        self.properties_slices = {}
        self.n_replicas = None
        self.n_molecules = None
        self.n_atoms = None

    def _init_data_stream(self, simulator: Simulator):
        """
        Determine the present properties and their shapes from
        simulator.system.properties, set up buffer and data group and, for a
        fresh stream, store the layout in the attributes of the data group.

        Args:
            simulator (schnetpack.md.Simulator): Simulator class used in the
                molecular dynamics simulation.

        Raises:
            FileLoggerError: If the system does not hold any properties yet or
                a target property is missing.
        """
        system = simulator.system
        self.n_replicas = system.n_replicas
        self.n_molecules = system.n_molecules

        if system.properties is None:
            raise FileLoggerError(
                "Shape of properties could not be determined, please call calculator"
            )

        # Determine present properties, order and shape thereof
        n_entries, shapes, positions = self._get_properties_structures(
            system.properties
        )

        # Set up storage
        self._setup_data_groups((self.n_replicas, n_entries), simulator)

        if self.restart:
            # Layout metadata already exists in the reopened group.
            return

        # Store metadata on shape and position of properties in array
        attrs = self.data_group.attrs
        attrs["shapes"] = json.dumps(shapes)
        attrs["positions"] = json.dumps(positions)
        attrs["n_replicas"] = system.n_replicas
        attrs["n_molecules"] = system.n_molecules
        attrs["n_atoms"] = system.n_atoms.cpu()

    def update_buffer(self, buffer_position: int, simulator: Simulator):
        """
        Copy the current values of all target properties into row
        `buffer_position` of the buffer.

        Args:
            buffer_position (int): Current position in the buffer.
            simulator (schnetpack.simulation_hooks.Simulator): Simulator class
                used in the molecular dynamics simulation.
        """
        # TODO: see why this detach is needed. Properties in model no buffer?
        # These are already detached in the calculator by default.
        for name, columns in self.properties_slices.items():
            flattened = (
                simulator.system.properties[name]
                .contiguous()
                .view(self.n_replicas, -1)
            )
            self.buffer[buffer_position : buffer_position + 1, :, columns] = (
                flattened.detach()
            )

    def _get_properties_structures(self, property_dict: Dict[str, torch.tensor]):
        """
        Collect shapes and positions of the target properties within the
        flattened property vector. Also fills `self.properties_slices` and, if
        no target properties were given, sets them to all keys of
        `property_dict`.

        Args:
            property_dict (dict(torch.Tensor)): Property dictionary of the main
                simulator.system class.

        Returns:
            int: Total length of the flattened property vector.
            dict(list): Shape of every property tensor without its first
                dimension.
            dict(tuple): Start and stop index of every property within the
                flattened vector.

        Raises:
            FileLoggerError: If a target property is not in `property_dict`.
        """
        # If no target properties are given, use everything in system properties
        if self.target_properties is None:
            self.target_properties = list(property_dict.keys())

        n_entries = 0
        shapes = {}
        positions = {}

        for name in self.target_properties:
            if name not in property_dict:
                raise FileLoggerError(
                    "Property {:s} not found in system properties".format(name)
                )

            # Shape without the leading dimension, as plain ints for JSON
            shape = [int(dim) for dim in property_dict[name].shape[1:]]
            shapes[name] = shape

            # The property occupies the next prod(shape) columns
            first = n_entries
            n_entries += int(np.prod(shape))

            positions[name] = (first, n_entries)
            self.properties_slices[name] = slice(first, n_entries)

        return n_entries, shapes, positions
class FileLoggerError(Exception):
    """Error raised by the FileLogger and its data streams."""
class FileLogger(SimulationHook):
    """
    Class for monitoring the simulation and storing the resulting data to a
    hdf5 dataset. The properties to monitor are given via instances of the
    DataStream class. Uses buffers of a given size, which are accumulated and
    flushed to the main file in regular intervals in order to reduce I/O
    overhead. All arrays are initialized for the full number of requested
    simulation steps, the current position in each data group is handled via
    the 'entries' attribute.

    Args:
        filename (str): Path to the hdf5 database file.
        buffer_size (int): Size of the buffer, once full, data is stored to the
            hdf5 dataset.
        data_streams (list(schnetpack.simulation_hooks.DataStream)): List of
            DataStreams used to collect and log information to the main hdf5
            dataset. If omitted, no streams are registered.
        every_n_steps (int): Frequency with which the buffer is updated.
        precision (int): Precision used for storing float data (16, 32, 64 bit,
            default 32).
    """

    def __init__(
        self,
        filename: str,
        buffer_size: int,
        data_streams: Union[List[DataStream], None] = None,
        every_n_steps: int = 1,
        precision: int = 32,
    ):
        super(FileLogger, self).__init__()
        self.every_n_steps = every_n_steps
        self.filename = filename
        self.buffer_size = buffer_size
        self.precision = precision

        # Holds the HDF5 file once the simulation has started
        self.file = None

        # Copy the streams so the caller's list is never shared or mutated.
        # NOTE: the attribute name `data_steams` (sic) is kept as is, since
        # code outside this class may access it.
        self.data_steams = [] if data_streams is None else list(data_streams)

        # Counters for file writes
        self.file_position = 0
        self.buffer_position = 0

    def on_simulation_start(self, simulator: Simulator):
        """
        Open the hdf5 file and initialize all data streams (creating groups,
        determining buffer shapes, storing metadata, etc.). If data is appended
        to an existing file, the current file position is read from the
        'entries' attribute of the data streams.

        Args:
            simulator (schnetpack.md.Simulator): Simulator class used in the
                molecular dynamics simulation.

        Raises:
            FileLoggerError: If the file already exists, no restart was
                requested and the simulator has not been run before.
        """
        # Data is appended if the file exists and either a restart was
        # requested or the simulator has already been called before. A missing
        # file is always created from scratch.
        append_data = False
        if os.path.exists(self.filename):
            append_data = bool(simulator.restart or (simulator.effective_steps > 0))
            if not append_data:
                raise FileLoggerError(
                    "File {:s} already exists and simulation was not restarted.".format(
                        self.filename
                    )
                )

        # Create or open the HDF5 file
        self.file = h5py.File(self.filename, "a", libver="latest")

        try:
            # Construct stream buffers and data groups
            for stream in self.data_steams:
                stream.init_data_stream(
                    simulator,
                    self.file,
                    self.buffer_size,
                    restart=append_data,
                    every_n_steps=self.every_n_steps,
                    precision=self.precision,
                )

                # Upon restart, get current position in file
                if append_data:
                    self.file_position = stream.data_group.attrs["entries"]

            # Enable single writer, multiple reader flag
            self.file.swmr_mode = True
        except Exception:
            # Do not leave a dangling open file handle behind if the setup
            # of a stream fails.
            self.file.close()
            self.file = None
            raise

    def on_step_finalize(self, simulator: Simulator):
        """
        Update the buffer of each stream after each specified interval and
        flush the buffer to the main file if full.

        Args:
            simulator (schnetpack.Simulator): Simulator class used in the
                molecular dynamics simulation.
        """
        if simulator.step % self.every_n_steps != 0:
            return

        # If buffers are full, write to file before adding the new entry
        if self.buffer_position == self.buffer_size:
            self._write_buffer()

        # Update stream buffers
        for stream in self.data_steams:
            stream.update_buffer(self.buffer_position, simulator)

        self.buffer_position += 1

    def on_simulation_end(self, simulator: Simulator):
        """
        Perform one final flush of the buffers and close the file upon the end
        of the simulation. The file is closed even if the final flush fails.

        Args:
            simulator (schnetpack.md.Simulator): Simulator class used in the
                molecular dynamics simulation.
        """
        try:
            # Flush remaining data in buffer
            if self.buffer_position != 0:
                self._write_buffer()
        finally:
            # Close database file (if it was ever opened)
            if self.file is not None:
                self.file.close()

    def _write_buffer(self):
        """
        Write all current buffers to the database file and reset the buffer
        position.
        """
        for stream in self.data_steams:
            stream.flush_buffer(self.file_position, self.buffer_position)

        # Advance by the number of rows actually written. This equals
        # buffer_size for a full buffer but is smaller for the final,
        # partially filled one.
        self.file_position += self.buffer_position
        self.buffer_position = 0
class TensorBoardLoggerError(Exception):
    """Error raised by the TensorBoard logging hooks."""
class BasicTensorboardLogger(SimulationHook):
    """
    Base class for logging scalar information of the system replicas and
    molecules collected during the simulation to TensorBoard. An individual
    scalar is created for every molecule, replica and property.

    Args:
        log_file (str): Path to the TensorBoard file.
        every_n_steps (int): Frequency with which data is logged to TensorBoard.
    """

    def __init__(self, log_file, every_n_steps=100):
        super().__init__()
        # Imported here so tensorboardX is only needed when the hook is used.
        from tensorboardX import SummaryWriter

        self.log_file = log_file
        self.every_n_steps = every_n_steps
        self.writer = SummaryWriter(self.log_file)

        # Set on simulation start
        self.n_replicas = None
        self.n_molecules = None

    def on_simulation_start(self, simulator):
        """
        Read the number of replicas and molecules from simulator.system.

        Args:
            simulator (schnetpack.simulation_hooks.Simulator): Simulator class
                used in the molecular dynamics simulation.
        """
        system = simulator.system
        self.n_replicas = system.n_replicas
        self.n_molecules = system.n_molecules

    def on_step_finalize(self, simulator: Simulator):
        """
        Collect and store scalar properties of replicas and molecules during
        the simulation. Has to be implemented by subclasses; in the easiest
        case the data is simply handed to the `_log_group` auxiliary function.

        Args:
            simulator (schnetpack.simulation_hooks.Simulator): Simulator class
                used in the molecular dynamics simulation.
        """
        raise NotImplementedError

    def _log_group(self, group_name, step, property, property_centroid=None):
        """
        Log the scalar data of one property. For every molecule a scalar group
        '<group_name>/molecule_XX' is written which holds one entry per replica
        ('rXX') and, if a centroid is given, an additional 'centroid' entry.

        Args:
            group_name (str): Base name of the property group to log.
            step (int): Current simulation step.
            property (torch.Tensor): Tensor of the shape (n_replicas x
                n_molecules) holding the scalar properties of each replica and
                molecule.
            property_centroid (torch.Tensor): Also store the centroid of the
                monitored property if provided (default=None).
        """
        # The same dictionary is reused for all molecules; every key is
        # overwritten before it is passed to the writer again.
        scalars = {}

        for mol_idx in range(self.n_molecules):
            tag = "{:s}/molecule_{:02d}".format(group_name, mol_idx + 1)

            if property_centroid is not None:
                scalars["centroid"] = property_centroid[0, mol_idx]

            for rep_idx in range(self.n_replicas):
                scalars["r{:02d}".format(rep_idx + 1)] = property[rep_idx, mol_idx]

            self.writer.add_scalars(tag, scalars, step)

    def on_simulation_end(self, simulator):
        """
        Close the TensorBoard logger.

        Args:
            simulator (schnetpack.simulation_hooks.Simulator): Simulator class
                used in the molecular dynamics simulation.
        """
        self.writer.close()
class TensorBoardLogger(BasicTensorboardLogger):
    """
    TensorBoard logging hook for the properties of the replicas, as well as of
    the corresponding centroids for each molecule in the system container.

    Args:
        log_file (str): Path to the TensorBoard file.
        properties (list): Names of the properties to log. Available are
            'energy', 'temperature', 'pressure' and 'volume'.
        every_n_steps (int): Frequency with which data is logged to TensorBoard.
    """

    def __init__(self, log_file: str, properties: List, every_n_steps: int = 100):
        super().__init__(log_file, every_n_steps=every_n_steps)

        # Maps each property name to the routine collecting its data
        self.get_properties = {
            "energy": self._get_energies,
            "temperature": self._get_temperature,
            "pressure": self._get_pressure,
            "volume": self._get_volume,
        }

        # Reject unknown property names right away
        for name in properties:
            if name not in self.get_properties:
                raise TensorBoardLoggerError(
                    "Property '{:s}' not available.".format(name)
                )

        self.properties = properties

    def on_step_finalize(self, simulator: Simulator):
        """
        Log the requested system properties at the given intervals.

        Args:
            simulator (schnetpack.simulation_hooks.Simulator): Simulator class
                used in the molecular dynamics simulation.
        """
        if simulator.step % self.every_n_steps != 0:
            return

        # Gather the data of all requested properties first ...
        collected = {}
        for name in self.properties:
            collected.update(self.get_properties[name](simulator.system))

        # ... then hand every group to the _log_group routine
        for group, entry in collected.items():
            self._log_group(
                group,
                simulator.step,
                entry[0],
                property_centroid=entry[1],
            )

    @staticmethod
    def _get_temperature(system: System):
        """
        Instructions for obtaining temperature and centroid temperature.

        Args:
            system (schnetpack.md.System): System class.

        Returns:
            Dict[str, Tuple]: Dictionary mapping the group name to a tuple of
            property and centroid.
        """
        return {
            "temperature": (system.temperature, system.centroid_temperature),
        }

    @staticmethod
    def _get_energies(system: System):
        """
        Instructions for obtaining kinetic, potential and total energy. If the
        potential energy has not been requested explicitly in the calculator
        (`energy_key`) it will be constantly 0.

        Args:
            system (schnetpack.md.System): System class.

        Returns:
            Dict[str, Tuple]: Dictionary mapping the group names to tuples of
            property and centroid.
        """
        e_kin = system.kinetic_energy
        e_kin_centroid = system.centroid_kinetic_energy

        e_pot = system.potential_energy
        e_pot_centroid = system.centroid_potential_energy

        return {
            "kinetic_energy": (e_kin, e_kin_centroid),
            "potential_energy": (e_pot, e_pot_centroid),
            "total_energy": (e_kin + e_pot, e_kin_centroid + e_pot_centroid),
        }

    @staticmethod
    def _get_volume(system: System):
        """
        Instructions for obtaining the volume. No centroid is logged.

        Args:
            system (schnetpack.md.System): System class.

        Returns:
            Dict[str, Tuple]: Dictionary mapping the group name to a tuple of
            property and centroid (None).
        """
        return {"volume": (system.volume, None)}

    @staticmethod
    def _get_pressure(system: System):
        """
        Instructions for obtaining pressure and centroid pressure, both
        including the kinetic component and divided by `schnetpack.units.bar`.

        Args:
            system (schnetpack.md.System): System class.

        Returns:
            Dict[str, Tuple]: Dictionary mapping the group name to a tuple of
            property and centroid.
        """
        bar = schnetpack.units.bar
        pressure = system.compute_pressure(kinetic_component=True) / bar
        pressure_centroid = (
            system.compute_centroid_pressure(kinetic_component=True) / bar
        )
        return {"pressure": (pressure, pressure_centroid)}