# Source code for md.system

"""
This module is used to store all information on the simulated atomistic systems.
It includes functionality for loading molecules from files.
All this functionality is encoded in the :obj:`schnetpack.md.System` class.
"""

import torch
import torch.nn as nn

from schnetpack.md.utils import NormalModeTransformer
from ase import Atoms

from typing import Union, List, OrderedDict

from schnetpack import units as spk_units
from schnetpack.md.utils import UninitializedMixin

__all__ = ["System"]


class SystemException(Exception):
    """Error raised when the molecular system is set up inconsistently."""


class SystemWarning(Warning):
    """Warning category for non-fatal problems with the molecular system."""


class System(UninitializedMixin, nn.Module):
    """
    Container for all properties associated with the simulated molecular system
    (masses, positions, momenta, ...). Uses MD unit system defined in
    `schnetpack.units` internally.

    In order to simulate multiple systems efficiently dynamics properties
    (positions, momenta, forces) are torch tensors with the following
    dimensions:
        n_replicas x (n_molecules * n_atoms) x 3

    Here n_replicas is the number of copies for every molecule. In a normal
    simulation, these are treated as independent molecules e.g. for sampling
    purposes. In the case of ring polymer molecular dynamics (using the
    RingPolymer integrator), these replicas correspond to the beads of the
    polymer. n_molecules is the number of different molecules constituting
    the system, these can e.g. be different initial configurations of the
    same system (once again for sampling) or completely different molecules.
    Atoms of multiple molecules are concatenated.

    Static properties are stored in tensors of the shape:
        n_atoms : n_molecules (the same for all replicas)
        masses : 1 x (n_molecules * n_atoms) x 1 (the same for all replicas)
        atom_types : (n_molecules * n_atoms)
        index_m : (n_molecules * n_atoms)

    `n_atoms` contains the number of atoms present in every molecule, `masses`
    and `atom_types` contain the atomic masses and nuclear charges.
    `index_m` is an index for mapping atoms to individual molecules.

    Finally a dictionary `properties` stores the results of every calculator
    call for easy access of e.g. energies and dipole moments.

    Device and floating point precision are not constructor arguments; the
    `device` and `dtype` properties report those of an internal dummy buffer.

    Args:
        normal_mode_transform (schnetpack.md.utils.NormalModeTransformer):
            Callable (by default the `NormalModeTransformer` class itself)
            which is invoked with the number of replicas in `load_molecules`
            and `load_system_state` to build the normal mode transformation.
    """

    # Property dictionary, updated during simulation. The class-level dict is
    # kept for backward compatibility; every instance receives its own dict in
    # `__init__` so that different systems do not share calculator results.
    properties = {}

    def __init__(
        self, normal_mode_transform: NormalModeTransformer = NormalModeTransformer
    ):
        super(System, self).__init__()

        # Per-instance property dictionary (shadows the shared class attribute)
        self.properties = {}

        self._nm_transformer = normal_mode_transform
        # For initialized nm transform
        self.nm_transform = None

        # Index for aggregation
        self.register_uninitialized_buffer("index_m", dtype=torch.long)

        # number of molecules, replicas of each and vector with the number of
        # atoms in each molecule
        self.n_replicas = None
        self.n_molecules = None
        self.total_n_atoms = None

        # General static molecular properties
        self.register_uninitialized_buffer("n_atoms", dtype=torch.long)
        self.register_uninitialized_buffer("atom_types", dtype=torch.long)
        self.register_uninitialized_buffer("masses")

        # Dynamic properties updated during simulation
        self.register_uninitialized_buffer("positions")
        self.register_uninitialized_buffer("momenta")
        self.register_uninitialized_buffer("forces")
        self.register_uninitialized_buffer("energy")

        # Properties for periodic boundary conditions and crystal cells
        self.register_uninitialized_buffer("cells")
        self.register_uninitialized_buffer("pbc")
        self.register_uninitialized_buffer(
            "stress"
        )  # Used for the computation of the pressure

        # Dummy tensor for device and dtype
        self.register_buffer("_dd_dummy", torch.zeros(1))

    @property
    def device(self):
        """
        Device of the internal dummy buffer.

        Returns:
            torch.device: device the dummy buffer currently lives on.
        """
        return self._dd_dummy.device

    @property
    def dtype(self):
        """
        Floating point type of the internal dummy buffer.

        Returns:
            torch.dtype: dtype of the dummy buffer.
        """
        return self._dd_dummy.dtype

    def load_molecules(
        self,
        molecules: Union[Atoms, List[Atoms]],
        n_replicas: int = 1,
        position_unit_input: Union[str, float] = "Angstrom",
        mass_unit_input: Union[str, float] = 1.0,
    ):
        """
        Initializes all required variables and tensors based on a list of ASE
        atoms objects.

        Args:
            molecules (ase.Atoms, list(ase.Atoms)): List of ASE atoms objects
                containing molecular structures and chemical elements.
            n_replicas (int): Number of replicas (e.g. for RPMD)
            position_unit_input (str, float): Position units of the input
                structures (default="Angstrom")
            mass_unit_input (str, float): Units of masses passed in the ASE
                atoms. Assumed to be Dalton.

        Raises:
            SystemException: If periodic boundary conditions are requested but
                all cells are zero.
        """
        self.n_replicas = n_replicas

        # TODO: make cells/PBC False if not set?

        # Set up unit conversion
        positions2internal = spk_units.unit2internal(position_unit_input)
        mass2internal = spk_units.unit2internal(mass_unit_input)

        # 0) Check if molecules is a single ase.Atoms object and wrap it in list.
        if isinstance(molecules, Atoms):
            molecules = [molecules]

        # 1) Get number of molecules
        self.n_molecules = len(molecules)

        # 2) Construct array with number of atoms in each molecule
        self.n_atoms = torch.tensor(
            [mol.get_global_number_of_atoms() for mol in molecules],
            dtype=torch.long,
        )

        # 3) Get total n_molecule x n_atom dimension
        self.total_n_atoms = torch.sum(self.n_atoms).item()
        # initialize index vector for aggregation
        self.index_m = torch.zeros(self.total_n_atoms, dtype=torch.long)

        # 4) Construct basic property arrays
        self.atom_types = torch.zeros(self.total_n_atoms, dtype=torch.long)
        self.masses = torch.ones(1, self.total_n_atoms, 1)

        # Relevant for dynamic properties: positions, momenta, forces
        self.positions = torch.zeros(self.n_replicas, self.total_n_atoms, 3)
        self.momenta = torch.zeros(self.n_replicas, self.total_n_atoms, 3)
        self.forces = torch.zeros(self.n_replicas, self.total_n_atoms, 3)

        self.energy = torch.zeros(self.n_replicas, self.n_molecules, 1)

        # Relevant for periodic boundary conditions and simulation cells
        self.cells = torch.zeros(self.n_replicas, self.n_molecules, 3, 3)
        self.stress = torch.zeros(self.n_replicas, self.n_molecules, 3, 3)
        self.pbc = torch.zeros(1, self.n_molecules, 3)

        # 5) Populate arrays according to the data provided in molecules
        idx_c = 0
        for i, molecule in enumerate(molecules):
            n_atoms = int(self.n_atoms[i])
            atom_slice = slice(idx_c, idx_c + n_atoms)

            # Aggregation array
            self.index_m[atom_slice] = i

            # Static properties
            self.atom_types[atom_slice] = torch.from_numpy(
                molecule.get_atomic_numbers()
            ).long()
            self.masses[0, atom_slice, 0] = torch.from_numpy(
                molecule.get_masses() * mass2internal
            )

            # Dynamic properties (identical starting positions for all replicas)
            self.positions[:, atom_slice, :] = torch.from_numpy(
                molecule.positions * positions2internal
            )

            # Properties for cell simulations
            self.cells[:, i, :, :] = torch.from_numpy(
                molecule.cell.array * positions2internal
            )
            self.pbc[0, i, :] = torch.from_numpy(molecule.pbc)

            idx_c += n_atoms

        # Convert periodic boundary conditions to Boolean tensor
        self.pbc = self.pbc.bool()

        # Check for cell/pbc stuff:
        if torch.sum(torch.abs(self.cells)) == 0.0:
            if torch.sum(self.pbc) > 0.0:
                raise SystemException("Found periodic boundary conditions but no cell.")

        # Set normal mode transformer
        self.nm_transform = self._nm_transformer(n_replicas)

    def sum_atoms(self, x: torch.Tensor):
        """
        Auxiliary routine for summing atomic contributions for each molecule.

        Args:
            x (torch.Tensor): Input tensor of the shape
                ( : x (n_molecules * n_atoms) x ...)

        Returns:
            torch.Tensor: Aggregated tensor of the shape ( : x n_molecules x ...)
        """
        x_shape = x.shape
        x_tmp = torch.zeros(
            x_shape[0], self.n_molecules, *x_shape[2:], device=x.device, dtype=x.dtype
        )
        return x_tmp.index_add(1, self.index_m, x)

    def _mean_atoms(self, x: torch.Tensor):
        """
        Auxiliary routine for computing mean over atomic contributions for each
        molecule.

        Args:
            x (torch.Tensor): Input tensor of the shape
                ( : x (n_molecules * n_atoms) x ...)

        Returns:
            torch.Tensor: Aggregated tensor of the shape ( : x n_molecules x ...)
        """
        # Give n_atoms one singleton axis per trailing dimension of x so that
        # the division broadcasts for inputs of any rank, not only rank 3.
        n_atoms = self.n_atoms.reshape(1, -1, *([1] * (x.dim() - 2)))
        return self.sum_atoms(x) / n_atoms

    def expand_atoms(self, x: torch.Tensor):
        """
        Auxiliary routine for expanding molecular contributions over the
        corresponding atoms.

        Args:
            x (torch.Tensor): Tensor of the shape ( : x n_molecules x ...)

        Returns:
            torch.Tensor: Tensor of the shape ( : x (n_molecules * n_atoms) x ...)
        """
        return x[:, self.index_m, ...]

    @property
    def center_of_mass(self):
        """
        Compute the center of mass for each replica and molecule

        Returns:
            torch.Tensor: n_replicas x n_molecules x 3 tensor holding the
                          center of mass.
        """
        # Mass weighted mean of the positions
        center_of_mass = self.sum_atoms(self.positions * self.masses) / self.sum_atoms(
            self.masses
        )
        return center_of_mass

    def remove_center_of_mass(self):
        """
        Move all structures to their respective center of mass.
        """
        self.positions -= self.expand_atoms(self.center_of_mass)

    def remove_translation(self):
        """
        Remove all components in the current momenta associated with
        translational motion.
        """
        self.momenta -= self.expand_atoms(self._mean_atoms(self.momenta))

    def remove_com_rotation(self):
        """
        Remove all components in the current momenta associated with rotational
        motion using Eckart conditions.
        """
        # Compute the moment of inertia tensor
        moment_of_inertia = (
            torch.sum(self.positions**2, dim=2, keepdim=True)[..., None]
            * torch.eye(3, dtype=self.positions.dtype, device=self.positions.device)[
                None, None, :, :
            ]
            - self.positions[..., :, None] * self.positions[..., None, :]
        )
        moment_of_inertia = self.sum_atoms(moment_of_inertia * self.masses[..., None])

        # Compute the angular momentum
        angular_momentum = self.sum_atoms(torch.cross(self.positions, self.momenta, -1))

        # Compute the angular velocities
        # NOTE(review): torch.inverse requires an invertible inertia tensor;
        # presumably callers avoid single atoms / linear molecules -- confirm.
        angular_velocities = torch.matmul(
            angular_momentum[:, :, None, :], torch.inverse(moment_of_inertia)
        ).squeeze(2)

        # Compute individual atomic contributions
        rotational_velocities = torch.cross(
            self.expand_atoms(angular_velocities), self.positions, -1
        )

        # Subtract rotation from overall motion
        self.momenta -= rotational_velocities * self.masses

    def get_ase_atoms(self, position_unit_output="Angstrom"):
        # TODO: make sensible unit conversion and update docs
        """
        Convert the stored molecular configurations into ASE Atoms objects.
        This is e.g. used for the neighbor lists based on environment
        providers. Positions and cells are converted from internal length
        units to `position_unit_output`.

        Args:
            position_unit_output (str, float): Target units for position output.

        Returns:
            list(ase.Atoms): List of ASE Atoms objects, with the replica and
                molecule dimension flattened.
        """
        internal2positions = spk_units.convert_units(
            spk_units.length, position_unit_output
        )

        atoms = []
        for idx_r in range(self.n_replicas):
            idx_c = 0
            for idx_m in range(self.n_molecules):
                n_atoms = int(self.n_atoms[idx_m])

                positions = (
                    self.positions[idx_r, idx_c : idx_c + n_atoms]
                    .cpu()
                    .detach()
                    .numpy()
                ) * internal2positions

                atom_types = (
                    self.atom_types[idx_c : idx_c + n_atoms].cpu().detach().numpy()
                )

                cell = (
                    self.cells[idx_r, idx_m].cpu().detach().numpy()
                    * internal2positions
                )
                # pbc is stored once per molecule and shared by all replicas
                pbc = self.pbc[0, idx_m].cpu().detach().numpy()

                mol = Atoms(atom_types, positions, cell=cell, pbc=pbc)
                atoms.append(mol)

                idx_c += n_atoms

        return atoms

    @property
    def velocities(self):
        """
        Convenience property to access molecular velocities instead of the
        momenta (e.g for power spectra)

        Returns:
            torch.Tensor: Velocity tensor with the same shape as the momenta.
        """
        return self.momenta / self.masses

    @property
    def kinetic_energy(self) -> torch.Tensor:
        """
        Convenience property for computing the kinetic energy associated with
        each replica and molecule.

        Returns:
            torch.Tensor: Tensor of the kinetic energies (in internal MD
                          units) with the shape n_replicas x n_molecules x 1
        """
        kinetic_energy = 0.5 * self.sum_atoms(
            torch.sum(self.momenta**2, dim=2, keepdim=True) / self.masses
        )
        return kinetic_energy

    @property
    def kinetic_energy_tensor(self):
        """
        Compute the kinetic energy tensor (outer product of momenta divided by
        masses) for pressure computation. The standard kinetic energy is the
        trace of this tensor.

        Returns:
            torch.Tensor: n_replicas x n_molecules x 3 x 3 tensor containing
                          kinetic energy components.
        """
        kinetic_energy_tensor = 0.5 * self.sum_atoms(
            self.momenta[..., None]
            * self.momenta[:, :, None, :]
            / self.masses[..., None]
        )
        return kinetic_energy_tensor

    @property
    def temperature(self):
        """
        Convenience property for accessing the instantaneous temperatures of
        each replica and molecule.

        Returns:
            torch.Tensor: Tensor of the instantaneous temperatures (in
                          Kelvin) with the shape n_replicas x n_molecules x 1
        """
        temperature = (
            2.0
            / (3.0 * self.n_atoms[None, :, None] * spk_units.kB)
            * self.kinetic_energy
        )
        return temperature

    @property
    def potential_energy(self):
        """
        Property for accessing potential energy of system. The energy array is
        only populated if a `energy_key` is given in the calculator, energies
        will be 0 otherwise.

        Returns:
            torch.Tensor: Potential energy, if requested in the calculator
        """
        return self.energy

    @potential_energy.setter
    def potential_energy(self, energy: torch.Tensor):
        """
        Setter for the potential energy.

        Args:
            energy (torch.Tensor): Potential energy.
        """
        self.energy = energy

    @property
    def momenta_normal(self):
        """
        Property for normal mode representation of momenta (e.g. for RPMD)

        Returns:
            torch.Tensor: momenta in normal mode representation.
        """
        return self.nm_transform.beads2normal(self.momenta)

    @momenta_normal.setter
    def momenta_normal(self, momenta_normal):
        """
        Use momenta in normal mode representation to set system momenta.

        Args:
            momenta_normal (torch.Tensor): momenta in normal mode representation
        """
        self.momenta = self.nm_transform.normal2beads(momenta_normal)

    @property
    def positions_normal(self):
        """
        Property for normal mode representation of positions (e.g. for RPMD)

        Returns:
            torch.Tensor: positions in normal mode representation.
        """
        return self.nm_transform.beads2normal(self.positions)

    @positions_normal.setter
    def positions_normal(self, positions_normal):
        """
        Use positions in normal mode representation to set system positions.

        Args:
            positions_normal (torch.Tensor): positions in normal mode
                representation
        """
        self.positions = self.nm_transform.normal2beads(positions_normal)

    @property
    def centroid_positions(self):
        """
        Convenience property to access the positions of the centroid during
        ring polymer molecular dynamics. Does not make sense during a
        standard dynamics setup.

        Returns:
            torch.Tensor: Tensor of the shape 1 x (n_molecules * n_atoms) x 3
            holding the centroid positions.
        """
        return torch.mean(self.positions, dim=0, keepdim=True)

    @property
    def centroid_momenta(self):
        """
        Convenience property to access the centroid momenta during ring
        polymer molecular dynamics. Does not make sense during a standard
        dynamics setup.

        Returns:
            torch.Tensor: Tensor of the shape 1 x (n_molecules * n_atoms) x 3
                          holding the centroid momenta.
        """
        return torch.mean(self.momenta, dim=0, keepdim=True)

    @property
    def centroid_velocities(self):
        """
        Convenience property to access the velocities of the centroid during
        ring polymer molecular dynamics (e.g. for computing power spectra).
        Does not make sense during a standard dynamics setup.

        Returns:
            torch.Tensor: Tensor of the shape 1 x (n_molecules * n_atoms) x 3
            holding the centroid velocities.
        """
        return self.centroid_momenta / self.masses

    @property
    def centroid_kinetic_energy(self):
        """
        Convenience property for computing the kinetic energy associated with
        the centroid of each molecule. Only sensible in the context of ring
        polymer molecular dynamics.

        Returns:
            torch.Tensor: Tensor of the centroid kinetic energies (in internal
                          MD units) with the shape 1 x n_molecules x 1
        """
        kinetic_energy = 0.5 * self.sum_atoms(
            torch.sum(self.centroid_momenta**2, dim=2, keepdim=True) / self.masses
        )
        return kinetic_energy

    @property
    def centroid_temperature(self):
        """
        Convenience property for accessing the instantaneous temperatures of
        the centroid of each molecule. Only makes sense in the context of
        ring polymer molecular dynamics.

        Returns:
            torch.Tensor: Tensor of the instantaneous centroid temperatures (
                          in Kelvin) with the shape 1 x n_molecules x 1
        """
        temperature = (
            2.0
            / (3.0 * spk_units.kB * self.n_atoms[None, :, None])
            * self.centroid_kinetic_energy
        )
        return temperature

    @property
    def centroid_potential_energy(self):
        """
        Get the centroid potential energy.

        Returns:
            torch.Tensor: Centroid potential energy
        """
        return torch.mean(self.energy, dim=0, keepdim=True)

    @property
    def volume(self):
        """
        Compute the cell volumes if cells are present.

        Returns:
            torch.Tensor: n_replicas x n_molecules x 1 containing the volumes.
        """
        # Triple product a . (b x c) of the three cell vectors
        volume = torch.sum(
            self.cells[:, :, 0]
            * torch.cross(self.cells[:, :, 1], self.cells[:, :, 2], dim=2),
            dim=2,
            keepdim=True,
        )
        return volume

    def compute_pressure(self, tensor: bool = False, kinetic_component: bool = False):
        """
        Compute the pressure (tensor) based on the stress tensor of the systems.

        Args:
            tensor (bool): Instead of a scalar pressure, return the full
                pressure tensor. (Required for anisotropic cell deformation.)
            kinetic_component (bool): Include the kinetic energy component
                during the computation of the pressure (default=False).

        Returns:
            torch.Tensor: Depending on the tensor-flag, returns a tensor
            containing the pressure with dimensions
            n_replicas x n_molecules x 1 (False) or
            n_replicas x n_molecules x 3 x 3 (True).

        Raises:
            SystemError: If any cell has zero volume.
        """
        volume = self.volume
        if torch.any(volume == 0.0):
            # NOTE(review): this is the builtin SystemError, not the module's
            # SystemException; kept unchanged since callers may catch it.
            raise SystemError(
                "Non-zero volume simulation cell required for computation of the instantaneous pressure."
            )

        # Negation creates a new tensor, so the in-place updates below do not
        # modify the stored stress.
        pressure = -self.stress

        if tensor:
            if kinetic_component:
                pressure += 2 * self.kinetic_energy_tensor / volume[..., None]
        else:
            # Scalar pressure is a third of the trace of the pressure tensor
            pressure = torch.einsum("abii->ab", pressure)[..., None] / 3.0
            if kinetic_component:
                pressure += 2.0 * self.kinetic_energy / volume / 3.0

        return pressure

    def compute_centroid_pressure(
        self, tensor: bool = False, kinetic_component: bool = False
    ):
        """
        Compute the pressure (tensor) of the ring polymer centroid based on the
        stress tensor of the systems.

        Args:
            tensor (bool): Instead of a scalar pressure, return the full
                pressure tensor. (Required for anisotropic cell deformation.)
            kinetic_component (bool): Include the kinetic energy component
                during the computation of the pressure (default=False).

        Returns:
            torch.Tensor: Depending on the tensor-flag, returns a tensor
            containing the pressure with dimensions
            1 x n_molecules x 1 (False) or
            1 x n_molecules x 3 x 3 (True).

        Raises:
            SystemError: If any replica-averaged cell volume is zero.
        """
        volume = torch.mean(self.volume, dim=0, keepdim=True)
        if torch.any(volume == 0.0):
            # NOTE(review): this is the builtin SystemError, not the module's
            # SystemException; kept unchanged since callers may catch it.
            raise SystemError(
                "Non-zero volume simulation cell required for computation of the instantaneous pressure."
            )

        # Compute centroid pressure
        pressure = -torch.mean(self.stress, dim=0, keepdim=True)

        if tensor:
            if kinetic_component:
                # Kinetic energy tensor of the centroid: outer product of the
                # centroid momenta divided by the masses, summed per molecule.
                # Its trace equals `centroid_kinetic_energy`, which keeps this
                # branch consistent with the scalar branch below.
                centroid_momenta = self.centroid_momenta
                centroid_kinetic_energy_tensor = 0.5 * self.sum_atoms(
                    centroid_momenta[..., None]
                    * centroid_momenta[:, :, None, :]
                    / self.masses[..., None]
                )
                pressure += 2 * centroid_kinetic_energy_tensor / volume[..., None]
        else:
            # Scalar pressure is a third of the trace of the pressure tensor
            pressure = torch.einsum("abii->ab", pressure)[..., None] / 3.0
            if kinetic_component:
                pressure += 2.0 * self.centroid_kinetic_energy / volume / 3.0

        return pressure

    def wrap_positions(self, eps=1e-6):
        """
        Move atoms outside the box back into the box for all dimensions with
        periodic boundary conditions.

        Args:
            eps (float): Small offset for numerical stability

        Raises:
            SystemWarning: If any cell has zero volume.
        """
        if torch.any(self.volume == 0.0):
            raise SystemWarning("Simulation cell required for wrapping of positions.")

        pbc_atomic = self.expand_atoms(self.pbc)

        # Compute fractional coordinates
        inverse_cell = torch.inverse(self.cells)
        inverse_cell = self.expand_atoms(inverse_cell)
        inv_positions = torch.sum(self.positions[..., None] * inverse_cell, dim=2)

        # Get periodic coordinates
        periodic = torch.masked_select(inv_positions, pbc_atomic)

        # Apply periodic boundary conditions (with small buffer)
        periodic = periodic + eps
        periodic = periodic % 1.0
        periodic = periodic - eps

        # Update fractional coordinates
        inv_positions.masked_scatter_(pbc_atomic, periodic)

        # Convert to positions
        self.positions = torch.sum(
            inv_positions[..., None] * self.expand_atoms(self.cells), dim=2
        )

    def load_system_state(self, state_dict: OrderedDict[str, torch.Tensor]):
        """
        Routine for restoring the state of a system specified in a previously
        stored state dict. Used to restart molecular dynamics simulations.

        Args:
            state_dict (dict): State dict of the system state.
        """
        self.load_state_dict(state_dict, strict=False)

        # Set PBC to bool for periodic boundary conditions
        self.pbc = self.pbc.bool()

        # Set derived properties for restarting
        self.n_replicas = self.positions.shape[0]
        self.total_n_atoms = self.positions.shape[1]
        self.n_molecules = self.n_atoms.shape[0]

        # Set normal mode transformer
        self.nm_transform = self._nm_transformer(self.n_replicas)