"""
This module contains different hooks for monitoring the simulation and checkpointing.
"""
from __future__ import annotations
from typing import Union, List, Dict, Tuple, Any
from typing import TYPE_CHECKING
import schnetpack.units
if TYPE_CHECKING:
from schnetpack.md import System
from schnetpack.md import Simulator
import torch
import json
import os
import h5py
import numpy as np
from schnetpack.md.simulation_hooks import SimulationHook
class Checkpoint(SimulationHook):
    """
    Hook that periodically saves the state_dict of the simulator, so that a simulation
    can be restarted from a previously stored configuration.

    A checkpoint is written every ``every_n_steps`` steps and once more at the end of the
    simulation. All writes target the same file, so only the most recent state is kept.

    Args:
        checkpoint_file (str): Path of the file the state_dict is written to.
        every_n_steps (int): Interval (in simulation steps) between two checkpoints.
    """

    def __init__(self, checkpoint_file: str, every_n_steps: int):
        super().__init__()
        self.checkpoint_file = checkpoint_file
        self.every_n_steps = every_n_steps

    def _save_checkpoint(self, simulator: Simulator):
        # Overwrite the checkpoint file with the current simulator state.
        torch.save(simulator.state_dict, self.checkpoint_file)

    def on_step_finalize(self, simulator: Simulator):
        """
        Save a checkpoint if the current step is a multiple of ``every_n_steps``.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        if simulator.step % self.every_n_steps != 0:
            return
        self._save_checkpoint(simulator)

    def on_simulation_end(self, simulator: Simulator):
        """
        Save a final checkpoint once the simulation has finished.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        self._save_checkpoint(simulator)
class DataStream:
    """
    Base class for the data streams written by the FileLogger.

    Each stream owns a single dataset called ``group_name`` inside the main hdf5 file.
    Values are collected in a torch buffer holding ``buffer_size`` steps and are written
    to the file whenever ``flush_buffer`` is called. Because the dataset is allocated for
    the full simulation length in advance, the number of rows holding valid data is kept
    in the dataset attribute ``entries``.

    Subclasses have to implement ``_init_data_stream`` (shape setup and metadata) and
    ``update_buffer`` (copying the current simulation state into the buffer).

    Args:
        group_name (str): Name of the dataset created in the hdf5 file.
    """

    def __init__(self, group_name: str):
        self.group_name = group_name

        # The remaining attributes are populated by init_data_stream / _setup_data_groups.
        self.precision = None
        self.buffer = None
        self.data_group = None
        self.main_dataset = None
        self.buffer_size = None
        self.restart = None
        self.every_n_steps = None

    @staticmethod
    def _precision(precision: int):
        """
        Translate a bit width into the matching numpy floating point type.

        Args:
            precision (int): Number of bits, e.g. 16, 32 or 64.

        Returns:
            type: The numpy type ``float<precision>``.

        Raises:
            AttributeError: If numpy provides no such floating point type.
        """
        numpy_name = f"float{precision}"
        try:
            return getattr(np, numpy_name)
        except AttributeError:
            raise AttributeError(f"Unknown float precision {precision}")

    def init_data_stream(
        self,
        simulator: Simulator,
        main_dataset,
        buffer_size: int,
        restart: bool = False,
        every_n_steps: int = 1,
        precision: int = 32,
    ):
        """
        Configure the stream and create (or reopen) its dataset and buffer.

        The dataset is sized for the full number of simulation steps up front, so the
        ``entries`` attribute records how many rows actually contain data. For a new
        dataset it is initialized to zero; on restart the existing value is kept.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
            main_dataset (h5py.File): Open main hdf5 file.
            buffer_size (int): Number of steps held in the buffer before it is written to the file.
            restart (bool): Continue logging into the already existing dataset (default=False).
            every_n_steps (int): Interval between logged simulation steps, e.g. used by
                MoleculeStream to compute the time between stored frames (default=1).
            precision (int): Bit width of the floats stored in the file (default=32).
        """
        self.main_dataset = main_dataset
        self.buffer_size = buffer_size
        self.restart = restart
        self.every_n_steps = every_n_steps
        self.precision = self._precision(precision)

        # Subclass hook: determines the data shape, calls _setup_data_groups and writes metadata.
        self._init_data_stream(simulator)

        if self.restart:
            # Keep the number of valid rows recorded by the previous run.
            return
        self.data_group.attrs["entries"] = 0

    def _init_data_stream(self, simulator: Simulator):
        """
        Stream specific initialization; must be implemented by subclasses.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        raise NotImplementedError

    def update_buffer(self, buffer_position: int, simulator: Simulator):
        """
        Copy the current simulation data into the buffer; must be implemented by subclasses.

        Args:
            buffer_position (int): Row of the buffer to fill.
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        raise NotImplementedError

    def flush_buffer(self, file_position: int, buffer_position: int):
        """
        Copy the first ``buffer_position`` buffer rows into the dataset.

        Rows at or beyond ``buffer_position`` are not written. Afterwards the ``entries``
        attribute is updated to the new end of valid data and the dataset is flushed.

        Args:
            file_position (int): Row of the dataset at which writing starts.
            buffer_position (int): Number of valid rows currently held in the buffer.
        """
        end = file_position + buffer_position
        chunk = self.buffer[:buffer_position].detach().cpu()
        self.data_group[file_position:end] = chunk
        self.data_group.attrs.modify("entries", end)
        self.data_group.flush()

    def _setup_data_groups(self, data_shape: Tuple[Any, int], simulator: Simulator):
        """
        Allocate the torch buffer and create or reopen the hdf5 dataset.

        The buffer has shape ``(buffer_size, *data_shape)`` and uses the device and dtype of
        ``simulator.system``. A new dataset is allocated for ``simulator.n_steps`` rows, is
        chunked like the buffer and may grow along its first axis. On restart the existing
        dataset is resized to ``simulator.n_steps`` rows beyond its current ``entries``.

        Args:
            data_shape (tuple(int)): Shape of the data recorded per step.
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        self.buffer = torch.zeros(
            (self.buffer_size, *data_shape),
            device=simulator.system.device,
            dtype=simulator.system.dtype,
        )

        if not self.restart:
            self.data_group = self.main_dataset.create_dataset(
                self.group_name,
                shape=(simulator.n_steps,) + data_shape,
                chunks=self.buffer.shape,
                dtype=self.precision,
                maxshape=(None,) + data_shape,
            )
            return

        # Reopen the dataset written by the previous run and make room for the new steps.
        self.data_group = self.main_dataset[self.group_name]
        new_length = simulator.n_steps + self.data_group.attrs["entries"]
        self.data_group.resize((new_length,) + data_shape)
class MoleculeStream(DataStream):
    """
    DataStream logging energies, positions and (optionally) velocities, cells and stress
    tensors to the dataset 'molecules' of the main hdf5 file.

    The dataset has the shape ``n_steps x n_replicas x data_dimension``. For every replica
    and logged step, the row is the concatenation of these flattened blocks, in this order:

    * energies (``system.energy``): ``n_molecules`` entries,
    * positions: ``total_n_atoms * 3`` entries,
    * velocities (only if ``store_velocities``): ``total_n_atoms * 3`` entries,
    * cells and stress tensors (only if no entry of ``system.volume`` is zero):
      ``n_molecules * 9`` entries each.

    The numbers of replicas, molecules and atoms, the atom types, masses, periodic boundary
    conditions, flags for the optional blocks and the time between two logged frames
    (integrator time step times ``every_n_steps``) are stored as dataset attributes.

    Args:
        store_velocities (bool): Store atom velocities in addition to positions.
    """

    def __init__(self, store_velocities: bool):
        super(MoleculeStream, self).__init__("molecules")
        self.store_velocities = store_velocities
        # Whether cells and stress tensors are logged; determined in _init_data_stream.
        self.cells = False
        # Not used by this class; kept for backwards compatibility.
        self.written = 0

    def _init_data_stream(self, simulator: Simulator):
        """
        Determine the row layout, set up buffer and dataset and store the metadata.

        Metadata attributes are only written for a new dataset; on restart the attributes
        written by the original run are kept.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        system = simulator.system

        # Energies (one per molecule) and positions are always stored.
        data_dimension = system.n_molecules + system.total_n_atoms * 3

        # Optional velocities.
        if self.store_velocities:
            data_dimension += system.total_n_atoms * 3

        # Cells and stress tensors (9 entries per molecule each) are only stored if no volume
        # is zero. The flag is re-evaluated on every initialization so that a reused stream
        # never keeps a stale value from a previous system.
        self.cells = not bool(torch.any(system.volume == 0.0))
        if self.cells:
            data_dimension += system.n_molecules * 2 * 9

        self._setup_data_groups((system.n_replicas, data_dimension), simulator)

        if self.restart:
            return

        attrs = self.data_group.attrs
        attrs["n_replicas"] = system.n_replicas
        attrs["n_molecules"] = system.n_molecules
        attrs["total_n_atoms"] = system.total_n_atoms
        attrs["n_atoms"] = system.n_atoms.cpu()
        attrs["atom_types"] = system.atom_types.cpu()
        # Keep only the atom axis of the first entry, matching the shape of the atom types.
        attrs["masses"] = system.masses.cpu()[0, :, 0]
        # Drop the leading broadcast dimension.
        attrs["pbc"] = system.pbc.cpu()[0]
        attrs["has_cells"] = self.cells
        attrs["has_velocities"] = self.store_velocities
        # Time between two logged frames, not the raw integrator step.
        attrs["time_step"] = simulator.integrator.time_step * self.every_n_steps
        attrs["every_n_steps"] = self.every_n_steps

    def update_buffer(self, buffer_position: int, simulator: Simulator):
        """
        Copy the current state of the system into row ``buffer_position`` of the buffer.

        Every tensor is flattened per replica and written into its column range. The
        order of the blocks matches the layout set up in ``_init_data_stream``.

        Args:
            buffer_position (int): Current position in the buffer.
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        system = simulator.system
        n_replicas = system.n_replicas
        atom_width = system.total_n_atoms * 3
        cell_width = system.n_molecules * 9

        # (tensor, number of columns) in storage order; column ranges are derived from the
        # system sizes, mirroring the data dimension computed in _init_data_stream.
        segments = [(system.energy, system.n_molecules), (system.positions, atom_width)]
        if self.store_velocities:
            segments.append((system.velocities, atom_width))
        if self.cells:
            segments.append((system.cells, cell_width))
            segments.append((system.stress, cell_width))

        start = 0
        for values, width in segments:
            stop = start + width
            self.buffer[buffer_position : buffer_position + 1, :, start:stop] = (
                values.view(n_replicas, -1).detach()
            )
            start = stop
class PropertyStream(DataStream):
    """
    DataStream logging the properties predicted by the calculator to the dataset
    'properties' of the main hdf5 file.

    Each selected entry of ``simulator.system.properties`` is flattened per replica and the
    results are concatenated. The dataset therefore has the shape
    ``n_steps x n_replicas x n_properties``, where ``n_properties`` is the total number of
    flattened property entries per replica. The original shape of every property (without
    its leading replica dimension) and its ``(start, stop)`` column range are stored as
    JSON in the ``shapes`` and ``positions`` dataset attributes.

    Since the shapes are taken from ``system.properties``, the calculator has to be called
    at least once before the stream is initialized.

    Args:
        target_properties (list): Names of the properties to store. If None (default), all
            properties present in ``system.properties`` at initialization are stored.
    """

    def __init__(self, target_properties: List[str] = None):
        super(PropertyStream, self).__init__("properties")
        self.n_replicas = None
        self.n_molecules = None
        # Not used by this class; kept for backwards compatibility.
        self.n_atoms = None
        # Column slice of every logged property within a flattened per-replica row.
        self.properties_slices = {}
        self.target_properties = target_properties

    def _init_data_stream(self, simulator: Simulator):
        """
        Determine the logged properties and their layout, set up buffer and dataset and
        store the layout metadata (only for new datasets).

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.

        Raises:
            FileLoggerError: If the system holds no properties yet or a requested property
                is missing.
        """
        system = simulator.system
        self.n_replicas = system.n_replicas
        self.n_molecules = system.n_molecules

        if system.properties is None:
            raise FileLoggerError(
                "Shape of properties could not be determined, please call calculator"
            )

        (
            properties_entries,
            properties_shape,
            properties_positions,
        ) = self._get_properties_structures(system.properties)

        self._setup_data_groups((self.n_replicas, properties_entries), simulator)

        if self.restart:
            return

        # Store metadata on shape and position of the properties in the flattened rows.
        attrs = self.data_group.attrs
        attrs["shapes"] = json.dumps(properties_shape)
        attrs["positions"] = json.dumps(properties_positions)
        attrs["n_replicas"] = system.n_replicas
        attrs["n_molecules"] = system.n_molecules
        attrs["n_atoms"] = system.n_atoms.cpu()

    def update_buffer(self, buffer_position: int, simulator: Simulator):
        """
        Copy the current properties into row ``buffer_position`` of the buffer.

        Args:
            buffer_position (int): Current position in the buffer.
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        properties = simulator.system.properties
        for name, columns in self.properties_slices.items():
            flat = properties[name].contiguous().view(self.n_replicas, -1)
            # NOTE: according to the original author the calculator already detaches the
            # properties by default; detaching again is a cheap safeguard.
            self.buffer[buffer_position : buffer_position + 1, :, columns] = flat.detach()

    def _get_properties_structures(self, property_dict: Dict[str, torch.Tensor]):
        """
        Determine the layout of the flattened property rows.

        If no target properties were given, all keys of ``property_dict`` are selected and
        stored in ``self.target_properties``. ``self.properties_slices`` is rebuilt.

        Args:
            property_dict (dict(str, torch.Tensor)): Property dictionary of simulator.system.

        Returns:
            int: Total number of flattened property entries per replica and time step.
            dict(str, list(int)): Shape of every property without its leading dimension.
            dict(str, tuple(int, int)): ``(start, stop)`` columns of every property.

        Raises:
            FileLoggerError: If a requested property is missing from ``property_dict``.
        """
        # If no target properties are given, use everything in the system properties.
        if self.target_properties is None:
            self.target_properties = list(property_dict.keys())

        # Rebuild the slices from scratch so no entries of a previous layout survive.
        self.properties_slices = {}

        properties_entries = 0
        properties_shape = {}
        properties_positions = {}
        for p in self.target_properties:
            if p not in property_dict:
                raise FileLoggerError(
                    "Property {:s} not found in system properties".format(p)
                )
            # Shape without the leading dimension, stored as metadata.
            properties_shape[p] = [int(i) for i in property_dict[p].shape[1:]]
            start = properties_entries
            properties_entries += int(np.prod(properties_shape[p]))
            properties_positions[p] = (start, properties_entries)
            self.properties_slices[p] = slice(start, properties_entries)

        return properties_entries, properties_shape, properties_positions
class FileLoggerError(Exception):
    """
    Error raised by the FileLogger and its data streams, e.g. when the output file already
    exists without a restart or when a requested property is not available.
    """
class FileLogger(SimulationHook):
    """
    Hook storing simulation data in an hdf5 file.

    The data to store is defined by a list of DataStream instances. Each stream collects
    its data in a buffer of ``buffer_size`` logged steps. The buffers are flushed to the
    file once full and at the end of the simulation, to reduce I/O overhead.

    The datasets are allocated for the full number of requested simulation steps, and the
    number of valid rows of each dataset is tracked in its 'entries' attribute. After
    initialization the file is switched to single-writer/multiple-reader (SWMR) mode.

    Args:
        filename (str): Path to the hdf5 file.
        buffer_size (int): Number of logged steps held in the buffers before writing to the file.
        data_streams (list(DataStream)): Data streams collecting the information written to
            the file (default: None, i.e. no streams).
        every_n_steps (int): Interval (in simulation steps) with which the buffers are updated.
        precision (int): Precision used for storing float data (16, 32, 64 bit, default 32).
    """

    def __init__(
        self,
        filename: str,
        buffer_size: int,
        data_streams: Union[List[DataStream], None] = None,
        every_n_steps: int = 1,
        precision: int = 32,
    ):
        super(FileLogger, self).__init__()
        self.every_n_steps = every_n_steps
        self.filename = filename
        self.buffer_size = buffer_size
        self.precision = precision

        # HDF5 file handle; opened in on_simulation_start, closed in on_simulation_end.
        self.file = None

        # Copy the streams so the caller's list is never shared.
        # NOTE: the (misspelled) attribute name is kept for backwards compatibility.
        self.data_steams = list(data_streams) if data_streams is not None else []

        # Next row to write in the hdf5 datasets / next free row in the stream buffers.
        self.file_position = 0
        self.buffer_position = 0

    def on_simulation_start(self, simulator: Simulator):
        """
        Open the hdf5 file and initialize all data streams.

        Streams create their datasets, buffers and metadata. When existing data is extended
        (restart or repeated simulator call), the write position is taken from the
        'entries' attribute of the datasets; otherwise writing starts at row 0.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.

        Raises:
            FileLoggerError: If the file already exists and the simulation is neither
                restarted nor a continuation of a previous simulator call.
        """
        append_data = False
        if os.path.exists(self.filename):
            # An existing file without restart on the very first call would be overwritten.
            if (not simulator.restart) and (simulator.effective_steps == 0):
                raise FileLoggerError(
                    "File {:s} already exists and simulation was not restarted.".format(
                        self.filename
                    )
                )
            # Restart or repeated call: continue the existing datasets.
            append_data = bool(simulator.restart or (simulator.effective_steps > 0))

        # Never carry write positions over from a previous run; for appended data the
        # file position is re-read from the datasets below.
        self.file_position = 0
        self.buffer_position = 0

        self.file = h5py.File(self.filename, "a", libver="latest")
        try:
            for stream in self.data_steams:
                stream.init_data_stream(
                    simulator,
                    self.file,
                    self.buffer_size,
                    restart=append_data,
                    every_n_steps=self.every_n_steps,
                    precision=self.precision,
                )
                # Upon restart, continue after the last valid row.
                if append_data:
                    self.file_position = int(stream.data_group.attrs["entries"])
        except Exception:
            # Do not leak the open file handle if a stream fails to initialize.
            self.file.close()
            raise

        # Enable single writer, multiple reader mode.
        self.file.swmr_mode = True

    def on_step_finalize(self, simulator: Simulator):
        """
        Update the stream buffers every ``every_n_steps`` steps, writing them to the file
        first if they are full.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        if simulator.step % self.every_n_steps != 0:
            return

        if self.buffer_position == self.buffer_size:
            self._write_buffer()

        for stream in self.data_steams:
            stream.update_buffer(self.buffer_position, simulator)
        self.buffer_position += 1

    def on_simulation_end(self, simulator: Simulator):
        """
        Flush the remaining buffer contents and close the file.

        The file is closed even if the final flush fails.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        try:
            if self.buffer_position != 0:
                self._write_buffer()
        finally:
            self.file.close()

    def _write_buffer(self):
        """
        Write the valid rows of all stream buffers to the file and reset the buffer position.
        """
        for stream in self.data_steams:
            stream.flush_buffer(self.file_position, self.buffer_position)
        # Advance by the rows actually written; the final flush may hold fewer than buffer_size.
        self.file_position += self.buffer_position
        self.buffer_position = 0
class TensorBoardLoggerError(Exception):
    """
    Error raised by the TensorBoard logging hooks, e.g. when an unsupported property is
    requested.
    """
class BasicTensorboardLogger(SimulationHook):
    """
    Base class for logging scalar quantities of the replicas and molecules to TensorBoard.

    For every molecule, one scalar group is written, holding one entry per replica and,
    if provided, an entry for the centroid. Subclasses implement ``on_step_finalize`` to
    decide what is logged.

    Args:
        log_file (str): Path to the TensorBoard file.
        every_n_steps (int): Frequency with which data is logged to TensorBoard.
    """

    def __init__(self, log_file, every_n_steps=100):
        super(BasicTensorboardLogger, self).__init__()
        # Imported here rather than at module level, presumably so tensorboardX is only
        # required when this logger is actually instantiated.
        from tensorboardX import SummaryWriter

        self.log_file = log_file
        self.every_n_steps = every_n_steps
        self.writer = SummaryWriter(self.log_file)

        # Filled in on_simulation_start from the simulated system.
        self.n_replicas = None
        self.n_molecules = None

    def on_simulation_start(self, simulator):
        """
        Remember the number of replicas and molecules of the simulated system.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        system = simulator.system
        self.n_replicas = system.n_replicas
        self.n_molecules = system.n_molecules

    def on_step_finalize(self, simulator: Simulator):
        """
        Collect and log the scalar quantities; must be implemented by subclasses, typically
        by passing the data to ``_log_group``.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        raise NotImplementedError

    def _log_group(self, group_name, step, property, property_centroid=None):
        """
        Log one scalar per replica (and optionally the centroid) for every molecule.

        The TensorBoard tag of molecule ``i`` is ``<group_name>/molecule_<i+1>``; replicas are
        named ``r01``, ``r02``, ... and the centroid entry is named ``centroid``.

        Args:
            group_name (str): Base name of the property group to log.
            step (int): Current simulation step.
            property (torch.Tensor): Values indexed as ``[replica, molecule]``.
            property_centroid (torch.Tensor): Optional centroid values, indexed as
                ``[0, molecule]`` (default=None).
        """
        for mol_idx in range(self.n_molecules):
            scalars = {}
            if property_centroid is not None:
                scalars["centroid"] = property_centroid[0, mol_idx]
            scalars.update(
                ("r{:02d}".format(rep_idx + 1), property[rep_idx, mol_idx])
                for rep_idx in range(self.n_replicas)
            )
            tag = "{:s}/molecule_{:02d}".format(group_name, mol_idx + 1)
            self.writer.add_scalars(tag, scalars, step)

    def on_simulation_end(self, simulator):
        """
        Close the TensorBoard writer.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        self.writer.close()
class TensorBoardLogger(BasicTensorboardLogger):
    """
    TensorBoard hook logging selected system properties for every replica and molecule,
    together with the corresponding centroid values where available.

    Args:
        log_file (str): Path to the TensorBoard file.
        properties (list): Names of the quantities to log. Supported are 'energy'
            (kinetic, potential and total energy), 'temperature', 'pressure' (divided by
            ``schnetpack.units.bar``) and 'volume' (logged without centroid).
        every_n_steps (int): Frequency with which data is logged to TensorBoard.

    Raises:
        TensorBoardLoggerError: If an unsupported property is requested.
    """

    def __init__(self, log_file: str, properties: List, every_n_steps: int = 100):
        super(TensorBoardLogger, self).__init__(log_file, every_n_steps=every_n_steps)

        # Maps every supported property name onto the routine extracting it from the system.
        self.get_properties = dict(
            energy=self._get_energies,
            temperature=self._get_temperature,
            pressure=self._get_pressure,
            volume=self._get_volume,
        )

        unsupported = [p for p in properties if p not in self.get_properties]
        if unsupported:
            raise TensorBoardLoggerError(
                "Property '{:s}' not available.".format(unsupported[0])
            )
        self.properties = properties

    def on_step_finalize(self, simulator: Simulator):
        """
        Log the requested properties every ``every_n_steps`` steps.

        Args:
            simulator (schnetpack.md.Simulator): Simulator used in the molecular dynamics simulation.
        """
        if simulator.step % self.every_n_steps != 0:
            return

        collected = {}
        for name in self.properties:
            collected.update(self.get_properties[name](simulator.system))

        for group, (values, centroid) in collected.items():
            self._log_group(group, simulator.step, values, property_centroid=centroid)

    @staticmethod
    def _get_temperature(system: System):
        """
        Collect the temperature and the centroid temperature.

        Args:
            system (schnetpack.md.System): System class.

        Returns:
            Dict[str, Tuple[torch.Tensor, torch.Tensor]]: Property and centroid per group name.
        """
        return {"temperature": (system.temperature, system.centroid_temperature)}

    @staticmethod
    def _get_energies(system: System):
        """
        Collect kinetic, potential and total energy together with their centroids. If the
        potential energy has not been requested explicitly in the calculator (`energy_key`)
        it will be constantly 0.

        Args:
            system (schnetpack.md.System): System class.

        Returns:
            Dict[str, Tuple[torch.Tensor, torch.Tensor]]: Property and centroid per group name.
        """
        kinetic = (system.kinetic_energy, system.centroid_kinetic_energy)
        potential = (system.potential_energy, system.centroid_potential_energy)
        total = (kinetic[0] + potential[0], kinetic[1] + potential[1])
        return {
            "kinetic_energy": kinetic,
            "potential_energy": potential,
            "total_energy": total,
        }

    @staticmethod
    def _get_volume(system: System):
        """
        Collect the volume; no centroid is provided.

        Args:
            system (schnetpack.md.System): System class.

        Returns:
            Dict[str, Tuple[torch.Tensor, None]]: Property and (missing) centroid.
        """
        return {"volume": (system.volume, None)}

    @staticmethod
    def _get_pressure(system: System):
        """
        Collect the pressure and centroid pressure (including the kinetic component),
        divided by ``schnetpack.units.bar``.

        Args:
            system (schnetpack.md.System): System class.

        Returns:
            Dict[str, Tuple[torch.Tensor, torch.Tensor]]: Property and centroid per group name.
        """
        bar = schnetpack.units.bar
        pressure = system.compute_pressure(kinetic_component=True) / bar
        pressure_centroid = system.compute_centroid_pressure(kinetic_component=True) / bar
        return {"pressure": (pressure, pressure_centroid)}