Source code for transform.casting

from typing import Dict
from typing import Dict, Optional
from schnetpack.utils import as_dtype

import torch

from .base import Transform

__all__ = ["CastMap", "CastTo32", "CastTo64"]


[docs]class CastMap(Transform): """ Cast all inputs according to type map. """ is_preprocessor: bool = True is_postprocessor: bool = True def __init__(self, type_map: Dict[str, str]): """ Args: type_map: dict with source_type: target_type (as strings) """ super().__init__() self.type_map = type_map def forward( self, inputs: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: for k, v in inputs.items(): vdtype = str(v.dtype).split(".")[-1] if vdtype in self.type_map: inputs[k] = v.to(dtype=as_dtype(self.type_map[vdtype])) return inputs
[docs]class CastTo32(CastMap): """Cast all float64 tensors to float32""" def __init__(self): super().__init__(type_map={"float64": "float32"})
[docs]class CastTo64(CastMap): """Cast all float32 tensors to float64""" def __init__(self): super().__init__(type_map={"float32": "float64"})