"""Cutoff and switching functions (module ``schnetpack.nn.cutoff``)."""

import math
import torch
from torch import nn

# Public API of this module: the cutoff modules, their functional counterparts,
# and the smooth switching module. ``_switch_component`` is an internal helper
# and is intentionally not exported.
__all__ = [
    "CosineCutoff",
    "MollifierCutoff",
    "mollifier_cutoff",
    "cosine_cutoff",
    "SwitchFunction",
]


def cosine_cutoff(input: torch.Tensor, cutoff: torch.Tensor):
    r"""Behler-style cosine cutoff.

    .. math::
       f(r) = \begin{cases}
        0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right]
          & r < r_\text{cutoff} \\
        0 & r \geqslant r_\text{cutoff} \\
        \end{cases}

    Args:
        input (torch.Tensor): radii :math:`r` at which the cutoff is evaluated.
        cutoff (torch.Tensor or float): cutoff radius :math:`r_\text{cutoff}`;
            must broadcast against ``input``.

    Returns:
        torch.Tensor: cutoff values (broadcast shape of ``input`` and
        ``cutoff``); exactly 0 wherever ``input >= cutoff``.
    """
    # Smooth cosine decay: 1 at r = 0, 0 at r = cutoff.
    input_cut = 0.5 * (torch.cos(input * math.pi / cutoff) + 1.0)
    # cos is periodic, so the expression above rises again for r > cutoff;
    # zero out everything at or beyond the cutoff radius. The mask is built in
    # the result's dtype so no extra float32 tensor is mixed in for other dtypes.
    mask = (input < cutoff).to(input_cut.dtype)
    return input_cut * mask


class CosineCutoff(nn.Module):
    r"""Behler-style cosine cutoff module.

    .. math::
       f(r) = \begin{cases}
        0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right]
          & r < r_\text{cutoff} \\
        0 & r \geqslant r_\text{cutoff} \\
        \end{cases}

    The cutoff radius is registered as a buffer, so it follows the module
    across devices and is included in its ``state_dict``.
    """

    def __init__(self, cutoff: float):
        """
        Args:
            cutoff (float): cutoff radius.
        """
        super().__init__()
        # Stored as a 1-element float32 tensor (same as the legacy
        # ``torch.FloatTensor([cutoff])`` constructor it replaces).
        self.register_buffer("cutoff", torch.tensor([cutoff], dtype=torch.float32))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the cosine cutoff to ``input``.

        Args:
            input (torch.Tensor): radii at which the cutoff is evaluated.

        Returns:
            torch.Tensor: cutoff values, 0 at and beyond the cutoff radius.
        """
        return cosine_cutoff(input, self.cutoff)
def mollifier_cutoff(input: torch.Tensor, cutoff: torch.Tensor, eps: torch.Tensor):
    r"""Mollifier cutoff scaled to have a value of 1 at :math:`r=0`.

    .. math::
       f(r) = \begin{cases}
        \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right)
          & r < r_\text{cutoff} \\
        0 & r \geqslant r_\text{cutoff} \\
        \end{cases}

    Args:
        input (torch.Tensor): radii :math:`r` at which the cutoff is evaluated.
        cutoff (torch.Tensor or float): cutoff radius; must broadcast against
            ``input``.
        eps (torch.Tensor or float): safety margin for numerical stability.
            Entries with ``input + eps >= cutoff`` are treated as outside the
            cutoff (value 0), so for non-negative radii the denominator
            :math:`1 - (r / r_\text{cutoff})^2` stays bounded away from zero.

    Returns:
        torch.Tensor: cutoff values (broadcast shape of the inputs); exactly 0
        wherever ``input + eps >= cutoff``.
    """
    inside = input + eps < cutoff
    ratio = input / cutoff
    mask = inside.to(ratio.dtype)
    # Zero the ratio outside the cutoff *before* the exponent: there the
    # denominator would be ~0 or negative, giving inf/nan values (inf * 0 = nan)
    # and nan gradients. With ratio = 0 the exponent is 0, exp(0) = 1, and the
    # final multiplication by the mask yields a clean 0.
    ratio = ratio * mask
    exponent = 1.0 - 1.0 / (1.0 - ratio**2)
    return torch.exp(exponent) * mask
class MollifierCutoff(nn.Module):
    r"""Mollifier cutoff module scaled to have a value of 1 at :math:`r=0`.

    .. math::
       f(r) = \begin{cases}
        \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right)
          & r < r_\text{cutoff} \\
        0 & r \geqslant r_\text{cutoff} \\
        \end{cases}

    ``cutoff`` and ``eps`` are registered as buffers, so they follow the
    module across devices and are included in its ``state_dict``.
    """

    def __init__(self, cutoff: float, eps: float = 1.0e-7):
        """
        Args:
            cutoff: Cutoff radius.
            eps: Safety margin; inputs with ``input + eps >= cutoff`` are
                treated as outside the cutoff for numerical stability.
        """
        super().__init__()
        # 1-element float32 tensors, identical to the legacy
        # ``torch.FloatTensor([...])`` constructor they replace.
        self.register_buffer("cutoff", torch.tensor([cutoff], dtype=torch.float32))
        self.register_buffer("eps", torch.tensor([eps], dtype=torch.float32))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the mollifier cutoff to ``input``.

        Args:
            input (torch.Tensor): radii at which the cutoff is evaluated.

        Returns:
            torch.Tensor: cutoff values, 0 at and beyond the cutoff radius.
        """
        return mollifier_cutoff(input, self.cutoff, self.eps)
def _switch_component(
    x: torch.Tensor, ones: torch.Tensor, zeros: torch.Tensor
) -> torch.Tensor:
    """Basic component of switching functions: ``exp(-1/x)`` for ``x > 0``, else 0.

    Args:
        x (torch.Tensor): normalized input values.
        ones (torch.Tensor): tensor of ones, same shape as ``x``.
        zeros (torch.Tensor): tensor of zeros, same shape as ``x``.

    Returns:
        torch.Tensor: ``exp(-1/x)`` where ``x > 0`` and 0 elsewhere.
    """
    # Substitute 1 for non-positive entries before dividing: torch.where still
    # evaluates (and backpropagates through) the discarded branch, and
    # exp(-1/x) at x <= 0 would produce inf/nan gradients there.
    x_ = torch.where(x <= 0, ones, x)
    return torch.where(x <= 0, zeros, torch.exp(-ones / x_))


class SwitchFunction(nn.Module):
    r"""Smooth switch that decays from 1 to 0 between `switch_on` and `switch_off`.

    With :math:`x = (r - r_\text{on}) / (r_\text{off} - r_\text{on})` and
    :math:`g(t) = \exp(-1/t)` for :math:`t > 0` (0 otherwise), the output is

    .. math::
       f(r) = \begin{cases}
        1 & x \leq 0 \\
        \frac{g(1 - x)}{g(x) + g(1 - x)} & 0 < x < 1 \\
        0 & x \geq 1 \\
        \end{cases}

    Note: ``switch_on`` and ``switch_off`` must differ; equal values make the
    normalization divide by zero.
    """

    def __init__(self, switch_on: float, switch_off: float):
        """
        Args:
            switch_on (float): Onset of switch.
            switch_off (float): Value from which on switch is 0.
        """
        super().__init__()
        # Use the default floating dtype explicitly, matching the legacy
        # ``torch.Tensor([...])`` constructor (also for int arguments).
        dtype = torch.get_default_dtype()
        self.register_buffer("switch_on", torch.tensor([switch_on], dtype=dtype))
        self.register_buffer("switch_off", torch.tensor([switch_off], dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): tensor to which switching function should be applied to.

        Returns:
            torch.Tensor: switch output
        """
        # Map [switch_on, switch_off] onto [0, 1].
        x = (x - self.switch_on) / (self.switch_off - self.switch_on)

        ones = torch.ones_like(x)
        zeros = torch.zeros_like(x)
        fp = _switch_component(x, ones, zeros)
        fm = _switch_component(1 - x, ones, zeros)

        # Outside (0, 1) fp + fm can be 0, so select the constant branches there.
        f_switch = torch.where(x <= 0, ones, torch.where(x >= 1, zeros, fm / (fp + fm)))
        return f_switch