# Source code for model.base

from __future__ import annotations

from typing import Dict, Optional, List

from schnetpack.transform import Transform
import schnetpack.properties as properties
from schnetpack.utils import as_dtype

import torch
import torch.nn as nn

__all__ = ["AtomisticModel", "NeuralNetworkPotential"]


class AtomisticModel(nn.Module):
    """
    Base class for all SchNetPack models.

    SchNetPack models should subclass `AtomisticModel` and implement the forward
    method. To use the automatic collection of required derivatives, each submodule
    that requires gradients w.r.t. the input should list them as strings in
    `submodule.required_derivatives = ["input_key"]`. The model needs to call
    `self.collect_derivatives()` at the end of its `__init__`.

    To make use of post-processing transforms, the model should call
    `input = self.postprocess(input)` at the end of its `forward`. The
    post-processors will only be applied if `do_postprocessing=True`.

    Example:
        class SimpleModel(AtomisticModel):
            def __init__(
                self,
                representation: nn.Module,
                output_module: nn.Module,
                postprocessors: Optional[List[Transform]] = None,
                input_dtype_str: str = "float32",
                do_postprocessing: bool = True,
            ):
                super().__init__(
                    input_dtype_str=input_dtype_str,
                    postprocessors=postprocessors,
                    do_postprocessing=do_postprocessing,
                )
                self.representation = representation
                self.output_module = output_module
                self.collect_derivatives()

            def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
                inputs = self.initialize_derivatives(inputs)

                inputs = self.representation(inputs)
                inputs = self.output_module(inputs)

                # apply postprocessing (if enabled)
                inputs = self.postprocess(inputs)
                return inputs
    """

    def __init__(
        self,
        postprocessors: Optional[List[Transform]] = None,
        input_dtype_str: str = "float32",
        do_postprocessing: bool = True,
    ):
        """
        Args:
            postprocessors: Post-processing transforms that may be initialized using
                the `datamodule`, but are not applied during training.
            input_dtype_str: The dtype of real inputs as string.
            do_postprocessing: If true, post-processing is activated.
        """
        super().__init__()
        self.input_dtype_str = input_dtype_str
        self.do_postprocessing = do_postprocessing
        # nn.ModuleList(None) produces an empty list, so `postprocessors=None` is fine.
        self.postprocessors = nn.ModuleList(postprocessors)
        # Filled in by `collect_derivatives()` / `collect_outputs()`; None until then.
        self.required_derivatives: Optional[List[str]] = None
        self.model_outputs: Optional[List[str]] = None

    def collect_derivatives(self) -> List[str]:
        """
        Gather the `required_derivatives` declared by this model and its submodules.

        The union of every non-None `required_derivatives` list found on
        `self.modules()` is stored in `self.required_derivatives` and returned.
        The list is sorted so its order is deterministic across runs.

        Returns:
            Sorted list of input keys that require gradients.
        """
        # Reset first: `self.modules()` also yields `self`, so a value left over from
        # an earlier call would otherwise be merged back into the new result.
        self.required_derivatives = None
        required_derivatives = set()
        for m in self.modules():
            if getattr(m, "required_derivatives", None) is not None:
                required_derivatives.update(m.required_derivatives)
        self.required_derivatives = sorted(required_derivatives)
        return self.required_derivatives

    def collect_outputs(self) -> List[str]:
        """
        Gather the `model_outputs` declared by this model and its submodules.

        The union of every non-None `model_outputs` list found on `self.modules()`
        is stored in `self.model_outputs` and returned, sorted for a deterministic
        order.

        Returns:
            Sorted list of output keys that `extract_outputs` will return.
        """
        # Reset first so a stale value on `self` is not merged into the new result.
        self.model_outputs = None
        model_outputs = set()
        for m in self.modules():
            if getattr(m, "model_outputs", None) is not None:
                model_outputs.update(m.model_outputs)
        self.model_outputs = sorted(model_outputs)
        return self.model_outputs

    def initialize_derivatives(
        self, inputs: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Enable gradient tracking on inputs listed in `self.required_derivatives`.

        Keys that are absent from `inputs` are skipped. Tensors are modified in
        place via `requires_grad_()` and the same dict is returned. Requires
        `collect_derivatives()` to have been called; otherwise iterating `None`
        raises a TypeError.
        """
        for p in self.required_derivatives:
            if p in inputs.keys():
                inputs[p].requires_grad_()
        return inputs

    def initialize_transforms(self, datamodule):
        """
        Hand `datamodule` to every `Transform` among this model's (sub)modules.

        This includes transforms registered in `self.postprocessors`, since
        `self.modules()` recurses into registered submodules. What each transform
        does with the datamodule is defined by `Transform.datamodule`.
        """
        for module in self.modules():
            if isinstance(module, Transform):
                module.datamodule(datamodule)

    def postprocess(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Apply the post-processors in order if `self.do_postprocessing` is true.

        Returns the (possibly transformed) inputs dict unchanged otherwise.
        """
        if self.do_postprocessing:
            # apply postprocessing
            for pp in self.postprocessors:
                inputs = pp(inputs)
        return inputs

    def extract_outputs(
        self, inputs: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Return a new dict containing only the entries named in `self.model_outputs`.

        Raises:
            KeyError: if a declared output is missing from `inputs`.
        """
        results = {k: inputs[k] for k in self.model_outputs}
        return results
class NeuralNetworkPotential(AtomisticModel):
    """
    A generic neural network potential class that sequentially applies a list of input
    modules, a representation module and a list of output modules.

    This can be flexibly configured for various tasks, e.g. property prediction or
    potential energy surfaces with response properties.
    """

    def __init__(
        self,
        representation: nn.Module,
        input_modules: Optional[List[nn.Module]] = None,
        output_modules: Optional[List[nn.Module]] = None,
        postprocessors: Optional[List[Transform]] = None,
        input_dtype_str: str = "float32",
        do_postprocessing: bool = True,
    ):
        """
        Args:
            representation: The module that builds representation from inputs.
            input_modules: Modules that are applied before representation, e.g. to
                modify input or add additional tensors for response properties.
            output_modules: Modules that predict output properties from the
                representation.
            postprocessors: Post-processing transforms that may be initialized using
                the `datamodule`, but are not applied during training.
            input_dtype_str: The dtype of real inputs.
            do_postprocessing: If true, post-processing is activated.
        """
        super().__init__(
            input_dtype_str=input_dtype_str,
            postprocessors=postprocessors,
            do_postprocessing=do_postprocessing,
        )
        self.representation = representation
        # nn.ModuleList(None) yields an empty list, so both module lists may be omitted.
        self.input_modules = nn.ModuleList(input_modules)
        self.output_modules = nn.ModuleList(output_modules)

        # Must run after all submodules are registered so their declarations are seen.
        self.collect_derivatives()
        self.collect_outputs()

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Run input modules, representation and output modules in sequence.

        Post-processing is applied afterwards (if enabled), and only the entries
        listed in `self.model_outputs` are returned.
        """
        # initialize derivatives for response properties
        inputs = self.initialize_derivatives(inputs)

        for m in self.input_modules:
            inputs = m(inputs)

        inputs = self.representation(inputs)

        for m in self.output_modules:
            inputs = m(inputs)

        # apply postprocessing (if enabled)
        inputs = self.postprocess(inputs)
        results = self.extract_outputs(inputs)

        return results