# Source code for transform.base

from typing import Optional, Dict

import torch
import torch.nn as nn

import schnetpack as spk

# Public names exported by ``from <this module> import *``.
__all__ = ["Transform", "TransformException"]


class TransformException(Exception):
    """Exception type for transform-related errors; adds no behavior beyond ``Exception``."""


class Transform(nn.Module):
    """
    Base class for all transforms.

    Transforms can be used as pre- or post-processing layers. They can also be
    used for other parts of a model that need to be initialized based on data.

    To implement a new transform, override the ``forward`` method. Preprocessors
    are applied to single examples, while postprocessors operate on batches. All
    transforms should return a modified ``inputs`` dictionary.

    Transforms that depend on data can extract what they need in
    ``datamodule``. This base class itself defines no ``__init__`` and stores no
    reference to the data or the datamodule.
    """

    def datamodule(self, value):
        """
        Extract all required information from the data module.

        Intended to be called automatically when using the PyTorch Lightning
        integration. The transform should also implement a way to set these
        things manually, so that it is usable independently of PL. The base
        implementation does nothing.

        Do not store the datamodule, as this does not work with TorchScript
        conversion!

        Args:
            value: the data module to extract information from.
        """

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Apply the transform to ``inputs``.

        Args:
            inputs: dictionary mapping string keys to tensors.

        Returns:
            The modified ``inputs`` dictionary (in subclasses).

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    def teardown(self):
        """Cleanup hook; the base implementation does nothing."""