model.AtomisticModel

class model.AtomisticModel(*args: Any, **kwargs: Any)

Base class for all SchNetPack models.

SchNetPack models should subclass AtomisticModel and implement the forward method. To use the automatic collection of required derivatives, each submodule that requires gradients w.r.t. the input should list the corresponding input keys as strings in submodule.required_derivatives = ["input_key"]. The model needs to call self.collect_derivatives() at the end of its __init__.
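
For instance, a submodule that needs energy gradients with respect to atomic positions could register the required input key as shown below. This is a minimal sketch: the module itself and the "positions"/"energy"/"forces" keys are illustrative placeholders rather than SchNetPack names; only the required_derivatives attribute is the mechanism described above.

from typing import Dict

import torch
import torch.nn as nn


class ForcesHead(nn.Module):
    """Illustrative output head; the dictionary keys are placeholders."""

    def __init__(self):
        super().__init__()
        # collect_derivatives() gathers these keys from all submodules;
        # initialize_derivatives() then enables gradients for them on the inputs.
        self.required_derivatives = ["positions"]

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        energy = inputs["energy"]
        positions = inputs["positions"]
        # forces as the negative gradient of the energy w.r.t. the positions
        inputs["forces"] = -torch.autograd.grad(
            energy.sum(), positions, create_graph=self.training
        )[0]
        return inputs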

To make use of post-processing transforms, the model should call input = self.postprocess(input) at the end of its forward. The post-processors are only applied if do_postprocessing=True.

Example

from typing import Dict, List, Optional

import torch
import torch.nn as nn

from schnetpack.model import AtomisticModel
from schnetpack.transform import Transform


class SimpleModel(AtomisticModel):
    def __init__(
        self,
        representation: nn.Module,
        output_module: nn.Module,
        postprocessors: Optional[List[Transform]] = None,
        input_dtype_str: str = "float32",
        do_postprocessing: bool = True,
    ):
        super().__init__(
            input_dtype_str=input_dtype_str,
            postprocessors=postprocessors,
            do_postprocessing=do_postprocessing,
        )
        self.representation = representation
        self.output_module = output_module
        self.collect_derivatives()

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        inputs = self.initialize_derivatives(inputs)
        inputs = self.representation(inputs)
        inputs = self.output_module(inputs)
        # apply postprocessing (if enabled)
        inputs = self.postprocess(inputs)
        return inputs
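
A minimal usage sketch of the SimpleModel class defined above. The IdentityStage modules stand in for a real representation and output module, and the single "positions" entry stands in for a full SchNetPack input dictionary; how much additional bookkeeping a real AtomisticModel subclass needs depends on the SchNetPack version, so treat this as illustrative only.

from typing import Dict

import torch
import torch.nn as nn


class IdentityStage(nn.Module):
    """Stand-in stage that passes the input dictionary through unchanged."""

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return inputs


model = SimpleModel(
    representation=IdentityStage(),
    output_module=IdentityStage(),
    postprocessors=None,       # post-processing transforms, e.g. from the datamodule
    do_postprocessing=False,   # postprocessors are only applied when this is True
)

outputs = model({"positions": torch.zeros(4, 3)})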

Parameters:
  • postprocessors – Post-processing transforms that may be initialized using the datamodule, but are not applied during training.

  • input_dtype_str – The dtype of real inputs, given as a string (e.g. "float32").

  • do_postprocessing – If True, post-processing is activated.