Source code for md.calculators.schnetpack_calculator

from __future__ import annotations
from typing import Union, List, Dict, TYPE_CHECKING

import schnetpack.atomistic.response
from schnetpack.utils import load_model

if TYPE_CHECKING:
    from schnetpack.md import System
    from schnetpack.model import AtomisticModel
    from schnetpack.md.neighborlist_md import NeighborListMD

import torch
import logging

from schnetpack.md.calculators.base_calculator import MDCalculator
from schnetpack.md.calculators.ensemble_calculator import EnsembleCalculator
from schnetpack.md.utils import activate_model_stress

log = logging.getLogger(__name__)

__all__ = ["SchNetPackCalculator", "SchNetPackEnsembleCalculator"]


[docs]class SchNetPackCalculator(MDCalculator): """ MD calculator for schnetpack models. Args: model_file (str): Path to stored schnetpack model. force_key (str): String indicating the entry corresponding to the molecular forces energy_unit (float, float): Conversion factor converting the energies returned by the used model back to internal MD units. position_unit (float, float): Conversion factor for converting the system positions to the units required by the model. neighbor_list (schnetpack.md.neighbor_list.MDNeighborList): Neighbor list object for determining which interatomic distances should be computed. energy_key (str, optional): If provided, label is used to store the energies returned by the model to the system. stress_key (str, optional): If provided, label is used to store the stress returned by the model to the system (required for constant pressure simulations). required_properties (list): List of properties to be computed by the calculator property_conversion (dict(float)): Optional dictionary of conversion factors for other properties predicted by the model. Only changes the units used for logging the various outputs. script_model (bool): convert loaded model to torchscript. """ def __init__( self, model_file: str, force_key: str, energy_unit: Union[str, float], position_unit: Union[str, float], neighbor_list: NeighborListMD, energy_key: str = None, stress_key: str = None, required_properties: List = [], property_conversion: Dict[str, Union[str, float]] = {}, script_model: bool = False, ): super(SchNetPackCalculator, self).__init__( required_properties=required_properties, force_key=force_key, energy_unit=energy_unit, position_unit=position_unit, energy_key=energy_key, stress_key=stress_key, property_conversion=property_conversion, ) self.script_model = script_model self.model = self._prepare_model(model_file) self.neighbor_list = neighbor_list def _prepare_model(self, model_file: str) -> AtomisticModel: """ Load an individual model. Args: model_file (str): path to model. Returns: AtomisticTask: loaded schnetpack model """ return self._load_model(model_file) def _load_model(self, model_file: str) -> AtomisticModel: """ Load an individual model, activate stress computation and convert to torch script if requested. Args: model_file (str): path to model. Returns: AtomisticTask: loaded schnetpack model """ log.info("Loading model from {:s}".format(model_file)) # load model and keep it on CPU, device can be changed afterwards model = load_model(model_file, device=torch.device("cpu")).to(torch.float64) model = model.eval() if self.stress_key is not None: log.info("Activating stress computation...") model = activate_model_stress(model, self.stress_key) if self.script_model: log.info("Converting model to torch script...") model = torch.jit.script(model) log.info("Deactivating inference mode for simulation...") self._deactivate_postprocessing(model) return model @staticmethod def _deactivate_postprocessing(model: AtomisticModel) -> AtomisticModel: if hasattr(model, "postprocessors"): for pp in model.postprocessors: if isinstance(pp, schnetpack.transform.AddOffsets): log.info("Found `AddOffsets` postprocessing module...") log.info( "Constant offset of {:20.11f} per atom will be removed...".format( pp.mean.detach().cpu().numpy() ) ) model.do_postprocessing = False return model def calculate(self, system: System): """ Main routine, generates a properly formatted input for the schnetpack model from the system, performs the computation and uses the results to update the system state. Args: system (schnetpack.md.System): System object containing current state of the simulation. """ inputs = self._generate_input(system) self.results = self.model(inputs) self._update_system(system) def _generate_input(self, system: System) -> Dict[str, torch.Tensor]: """ Function to extracts neighbor lists, atom_types, positions e.t.c. from the system and generate a properly formatted input for the schnetpack model. Args: system (schnetpack.md.System): System object containing current state of the simulation. Returns: dict(torch.Tensor): Schnetpack inputs in dictionary format. """ inputs = self._get_system_molecules(system) neighbors = self.neighbor_list.get_neighbors(inputs) inputs.update(neighbors) return inputs
[docs]class SchNetPackEnsembleCalculator(EnsembleCalculator, SchNetPackCalculator): """ Ensemble calculator using schnetpack models. Uncertainties are computed as the variance of all model predictions. """ def __init__( self, model_files: List[str], force_key: str, energy_unit: Union[str, float], position_unit: Union[str, float], neighbor_list: NeighborListMD, energy_key: str = None, stress_key: str = None, required_properties: List = [], property_conversion: Dict[str, Union[str, float]] = {}, script_model: bool = True, ): """ Args: model_files (list(str)): List of paths to stored schnetpack model to be used in ensemble. force_key (str): String indicating the entry corresponding to the molecular forces energy_unit (float, float): Conversion factor converting the energies returned by the used model back to internal MD units. position_unit (float, float): Conversion factor for converting the system positions to the units required by the model. neighbor_list (schnetpack.md.neighbor_list.MDNeighborList): Neighbor list object for determining which interatomic distances should be computed. energy_key (str, optional): If provided, label is used to store the energies returned by the model to the system. stress_key (str, optional): If provided, label is used to store the stress returned by the model to the system (required for constant pressure simulations). required_properties (list): List of properties to be computed by the calculator property_conversion (dict(float)): Optional dictionary of conversion factors for other properties predicted by the model. Only changes the units used for logging the various outputs. script_model (bool): convert loaded model to torchscript. """ super(SchNetPackEnsembleCalculator, self).__init__( model_file=model_files, required_properties=required_properties, force_key=force_key, energy_unit=energy_unit, position_unit=position_unit, neighbor_list=neighbor_list, energy_key=energy_key, stress_key=stress_key, property_conversion=property_conversion, script_model=script_model, ) # Update the required properties self._update_required_properties() # Convert list of models to module list self.models = torch.nn.ModuleList(self.model) def _prepare_model(self, model_files: List[str]) -> List[AtomisticModel]: """ Load multiple models. Args: model_files (list(str)): List of stored models. Returns: list(AtomisticModel): list of loaded models. """ return [self._load_model(model_file) for model_file in model_files]