from typing import Dict, Optional
import torch
from ase.data import atomic_masses
import schnetpack.properties as structure
from .base import Transform
from schnetpack.nn import scatter_add
# Names exported by this module when it is star-imported.
__all__ = [
    "SubtractCenterOfMass",
    "SubtractCenterOfGeometry",
    "AddOffsets",
    "RemoveOffsets",
    "ScaleProperty",
]
class SubtractCenterOfMass(Transform):
    """
    Shift the positions so that the center of mass sits at the origin.

    Masses are taken from ASE's ``atomic_masses`` table, indexed by the atomic
    numbers stored under ``structure.Z``.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = False

    def __init__(self):
        super().__init__()

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Subtract the mass-weighted mean position from ``structure.position``.

        Args:
            inputs: dictionary holding at least ``structure.Z`` and
                ``structure.position``.

        Returns:
            The same dictionary, with the positions modified in place.
        """
        # Look up one mass per entry of Z in the ASE table.
        atom_masses = torch.tensor(atomic_masses[inputs[structure.Z]])
        # Mass-weighted sum over the first axis, normalised by the total mass.
        # NOTE(review): assumes positions are laid out as (n_atoms, 3) - confirm.
        weighted_positions = atom_masses.unsqueeze(-1) * inputs[structure.position]
        center_of_mass = weighted_positions.sum(0) / atom_masses.sum()
        inputs[structure.position] -= center_of_mass
        return inputs
class SubtractCenterOfGeometry(Transform):
    """
    Shift the positions so that their unweighted mean sits at the origin.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = False

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Subtract the mean over the first axis from ``structure.position``.

        Args:
            inputs: dictionary holding at least ``structure.position``.

        Returns:
            The same dictionary, with the positions modified in place.
        """
        centroid = inputs[structure.position].mean(0)
        inputs[structure.position] -= centroid
        return inputs
class RemoveOffsets(Transform):
    """
    Remove offsets from property based on the mean of the training data and/or the
    single atom reference calculations.

    The `mean` and/or `atomref` are automatically obtained from the AtomsDataModule,
    when it is used. Otherwise, they have to be provided in the init manually.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = True

    def __init__(
        self,
        property,
        remove_mean: bool = False,
        remove_atomrefs: bool = False,
        is_extensive: bool = True,
        zmax: int = 100,
        atomrefs: Optional[torch.Tensor] = None,
        property_mean: Optional[torch.Tensor] = None,
        estimate_atomref: bool = False,
    ):
        """
        Args:
            property: The property to remove the offsets from.
            remove_mean: If true, remove mean of the dataset from property.
            remove_atomrefs: If true, remove single-atom references.
            is_extensive: Set true if the property is extensive.
            zmax: Set the maximum atomic number, to determine the size of the atomref
                tensor.
            atomrefs: Provide single-atom references directly.
            property_mean: Provide mean property value / n_atoms.
            estimate_atomref: If true, the atom references are requested from the
                datamodule via `get_atomrefs` instead of being read from the
                training dataset. Can not be combined with `atomrefs`.
        """
        super().__init__()
        self._property = property
        self.remove_mean = remove_mean
        self.remove_atomrefs = remove_atomrefs
        self.is_extensive = is_extensive
        self.estimate_atomref = estimate_atomref

        assert not (
            estimate_atomref and atomrefs is not None
        ), "You can not set `atomrefs` and use `estimate_atomrefs=True!`"

        # Remember whether values were supplied by the caller, so that
        # `datamodule` does not overwrite them later.
        self._atomrefs_initialized = atomrefs is not None
        self._mean_initialized = property_mean is not None

        # Compare against None explicitly: the truth value of a tensor with more
        # than one element is ambiguous (raises), and a zero-valued scalar tensor
        # is falsy, so `tensor or default` must not be used here.
        if self.remove_atomrefs:
            if atomrefs is None:
                atomrefs = torch.zeros((zmax,))
            self.register_buffer("atomref", atomrefs)
        if self.remove_mean:
            if property_mean is None:
                property_mean = torch.zeros((1,))
            self.register_buffer("mean", property_mean)

    def datamodule(self, _datamodule):
        """
        Sets mean and atomref automatically when using PyTorchLightning integration.
        Values that were passed to the constructor are left untouched.
        """
        if self.remove_atomrefs and not self._atomrefs_initialized:
            if self.estimate_atomref:
                atrefs = _datamodule.get_atomrefs(
                    property=self._property, is_extensive=self.is_extensive
                )
            else:
                atrefs = _datamodule.train_dataset.atomrefs
            self.atomref = atrefs[self._property].detach()

        if self.remove_mean and not self._mean_initialized:
            stats = _datamodule.get_stats(
                self._property, self.is_extensive, self.remove_atomrefs
            )
            # The first entry of the stats is used as the mean.
            self.mean = stats[0].detach()

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Subtract the configured offsets from `inputs[property]` in place.

        Args:
            inputs: dictionary holding the property as well as `structure.n_atoms`
                and/or `structure.Z`, depending on the enabled offsets.

        Returns:
            The same dictionary with the offsets removed from the property.
        """
        if self.remove_mean:
            # For extensive properties the stored mean is scaled by the atom count.
            mean = (
                self.mean * inputs[structure.n_atoms]
                if self.is_extensive
                else self.mean
            )
            inputs[self._property] -= mean

        if self.remove_atomrefs:
            # NOTE(review): the sum runs over every entry of Z, so this looks like
            # it expects a single structure rather than a batch - confirm.
            atomref_bias = torch.sum(self.atomref[inputs[structure.Z]])
            if not self.is_extensive:
                atomref_bias /= inputs[structure.n_atoms].item()
            inputs[self._property] -= atomref_bias

        return inputs
class ScaleProperty(Transform):
    """
    Scale an entry of the input or results dictionary.

    The `scale` can be automatically obtained from the AtomsDataModule,
    when it is used. Otherwise, it has to be provided in the init manually.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = True

    def __init__(
        self,
        input_key: str,
        target_key: Optional[str] = None,
        output_key: Optional[str] = None,
        scale_by_mean: bool = False,
        scale: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            input_key: dict key of input to be scaled
            target_key: dict key of target to derive scaling from
                (either its mean or std dev); defaults to `input_key`
            output_key: dict key for scaled output; defaults to `input_key`
            scale_by_mean: if true, use the mean of the target variable for scaling,
                otherwise use its standard deviation
            scale: provide the scale of the property manually.
        """
        super().__init__()
        self.input_key = input_key
        self._target_key = target_key or input_key
        self.output_key = output_key or input_key
        self._scale_by_mean = scale_by_mean
        self.model_outputs = [self.output_key]

        # Remember whether the scale was supplied by the caller, so that
        # `datamodule` does not overwrite it later.
        self._initialized = scale is not None

        # Compare against None explicitly: `scale or default` would raise for a
        # multi-element tensor and would silently replace an explicit scale of
        # zero by one.
        if scale is None:
            scale = torch.ones((1,))
        self.register_buffer("scale", scale)

    def datamodule(self, _datamodule):
        """
        Derive the scale from the datamodule statistics of the target key, unless
        a scale was passed to the constructor.
        """
        if not self._initialized:
            # NOTE(review): the positional flags presumably mean
            # "divide by atoms" and "remove atomrefs" - confirm against get_stats.
            stats = _datamodule.get_stats(self._target_key, True, False)
            scale = stats[0] if self._scale_by_mean else stats[1]
            # Only the magnitude is used as scale.
            self.scale = torch.abs(scale).detach()

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Write `inputs[input_key] * scale` to `inputs[output_key]`.

        Args:
            inputs: dictionary holding the entry stored under `input_key`.

        Returns:
            The same dictionary with the scaled entry stored under `output_key`.
        """
        inputs[self.output_key] = inputs[self.input_key] * self.scale
        return inputs
class AddOffsets(Transform):
    """
    Add offsets to property based on the mean of the training data and/or the single
    atom reference calculations.

    The `mean` and/or `atomref` are automatically obtained from the AtomsDataModule,
    when it is used. Otherwise, they have to be provided in the init manually.

    Hint:
        Place this postprocessor after casting to float64 for higher numerical
        precision.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = True
    atomref: torch.Tensor

    def __init__(
        self,
        property,
        add_mean: bool = False,
        add_atomrefs: bool = False,
        is_extensive: bool = True,
        zmax: int = 100,
        atomrefs: Optional[torch.Tensor] = None,
        property_mean: Optional[torch.Tensor] = None,
        estimate_atomref: bool = False,
    ):
        """
        Args:
            property: The property to add the offsets to.
            add_mean: If true, add mean of the dataset.
            add_atomrefs: If true, add single-atom references.
            is_extensive: Set true if the property is extensive.
            zmax: Set the maximum atomic number, to determine the size of the atomref
                tensor.
            atomrefs: Provide single-atom references directly.
            property_mean: Provide mean property value / n_atoms.
            estimate_atomref: If true, the atom references are requested from the
                datamodule via `get_atomrefs` instead of being read from the
                training dataset. Can not be combined with `atomrefs`.
        """
        super().__init__()
        self._property = property
        self.add_mean = add_mean
        self.add_atomrefs = add_atomrefs
        self.is_extensive = is_extensive
        self._aggregation = "sum" if self.is_extensive else "mean"
        self.estimate_atomref = estimate_atomref

        assert not (
            estimate_atomref and atomrefs is not None
        ), "You can not set `atomrefs` and use `estimate_atomrefs=True!`"

        # Remember whether values were supplied by the caller, so that
        # `datamodule` does not overwrite them later.
        self._atomrefs_initialized = atomrefs is not None
        self._mean_initialized = property_mean is not None

        # Compare against None explicitly: the truth value of a tensor with more
        # than one element is ambiguous (raises), and a zero-valued scalar tensor
        # is falsy, so `tensor or default` must not be used here.
        if atomrefs is None:
            atomrefs = torch.zeros((zmax,))
        if property_mean is None:
            property_mean = torch.zeros((1,))

        # Both buffers are always registered, independent of the add_* flags.
        self.register_buffer("atomref", atomrefs)
        self.register_buffer("mean", property_mean)

    def datamodule(self, _datamodule):
        """
        Fetch mean and atomref from the datamodule for the enabled offsets, unless
        they were passed to the constructor.
        """
        if self.add_atomrefs and not self._atomrefs_initialized:
            if self.estimate_atomref:
                atrefs = _datamodule.get_atomrefs(
                    property=self._property, is_extensive=self.is_extensive
                )
            else:
                atrefs = _datamodule.train_dataset.atomrefs
            self.atomref = atrefs[self._property].detach()

        if self.add_mean and not self._mean_initialized:
            stats = _datamodule.get_stats(
                self._property, self.is_extensive, self.add_atomrefs
            )
            # The first entry of the stats is used as the mean.
            self.mean = stats[0].detach()

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Add the configured offsets to `inputs[property]` in place.

        Args:
            inputs: dictionary holding the property as well as `structure.n_atoms`
                and, if atom references are added, `structure.Z` and
                `structure.idx_m`.

        Returns:
            The same dictionary with the offsets added to the property.
        """
        if self.add_mean:
            # For extensive properties the stored mean is scaled by the atom count.
            mean = (
                self.mean * inputs[structure.n_atoms]
                if self.is_extensive
                else self.mean
            )
            inputs[self._property] += mean

        if self.add_atomrefs:
            idx_m = inputs[structure.idx_m]
            # One reference value per entry of Z.
            y0i = self.atomref[inputs[structure.Z]]
            # NOTE(review): presumably idx_m maps each atom to its structure and is
            # sorted, so the last entry gives the number of structures - confirm.
            maxm = int(idx_m[-1]) + 1
            # Accumulate the per-atom references by their idx_m index.
            y0 = scatter_add(y0i, idx_m, dim_size=maxm)

            if not self.is_extensive:
                y0 /= inputs[structure.n_atoms]

            inputs[self._property] += y0

        return inputs