Source code for schnetpack.nn.radial

from math import pi

import torch
import torch.nn as nn

__all__ = ["gaussian_rbf", "GaussianRBF", "GaussianRBFCentered", "BesselRBF"]

from torch import nn as nn


def gaussian_rbf(inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor):
    r"""Expand every element of ``inputs`` in a set of Gaussian functions.

    For each input value :math:`x` and each basis index :math:`k` this evaluates
    :math:`\exp\left(-\frac{(x - \mu_k)^2}{2\sigma_k^2}\right)`.

    Args:
        inputs: values to expand.
        offsets: Gaussian centers :math:`\mu_k`; must broadcast against a new
            trailing axis of ``inputs`` (the classes in this module pass a 1-D
            tensor of length ``n_rbf``).
        widths: Gaussian widths :math:`\sigma_k`, broadcastable like ``offsets``.
            A zero width produces NaN where the input equals the center.

    Returns:
        Tensor with one extra trailing axis, indexing the basis functions.
    """
    # Pre-compute the exponent scale once per basis function.
    scale = -0.5 / widths.pow(2)
    # Add a trailing axis so every input is compared with every center.
    shifted = inputs.unsqueeze(-1) - offsets
    return torch.exp(scale * shifted.pow(2))


class GaussianRBF(nn.Module):
    r"""Gaussian radial basis functions with evenly spaced centers.

    Each input :math:`x` is expanded into ``n_rbf`` features
    :math:`\exp\left(-\frac{(x - \mu_k)^2}{2\sigma^2}\right)`. The centers
    :math:`\mu_k` are spaced evenly on ``[start, cutoff]``, and all Gaussians
    share one width :math:`\sigma`, equal to the spacing between neighbouring
    centers.
    """

    def __init__(
        self, n_rbf: int, cutoff: float, start: float = 0.0, trainable: bool = False
    ):
        r"""
        Args:
            n_rbf: total number of Gaussian functions, :math:`N_g`. Must be at
                least 2, because the shared width is derived from the spacing
                of the first two centers.
            cutoff: center of last Gaussian function, :math:`\mu_{N_g}`.
            start: center of first Gaussian function, :math:`\mu_0`. Must differ
                from ``cutoff``, otherwise every width would be zero.
            trainable: If True, widths and offsets of the Gaussian functions are
                registered as ``nn.Parameter`` and adjusted during training;
                otherwise they are stored as buffers.

        Raises:
            ValueError: if ``n_rbf < 2`` or ``start == cutoff``.
        """
        super().__init__()
        if n_rbf < 2:
            raise ValueError(
                f"n_rbf must be at least 2 to derive the Gaussian width, got {n_rbf}"
            )
        if start == cutoff:
            raise ValueError(
                "start and cutoff must differ, otherwise all Gaussian widths are zero"
            )
        self.n_rbf = n_rbf

        # compute offset and width of Gaussian functions
        offset = torch.linspace(start, cutoff, n_rbf)
        # One shared width equal to the center spacing. Building it with
        # ones_like keeps the dtype/device of ``offset``. The legacy
        # torch.FloatTensor(tensor) constructor used previously requires a
        # CPU float32 tensor and fails under a float64 default dtype.
        widths = torch.abs(offset[1] - offset[0]) * torch.ones_like(offset)

        if trainable:
            self.widths = nn.Parameter(widths)
            self.offsets = nn.Parameter(offset)
        else:
            self.register_buffer("widths", widths)
            self.register_buffer("offsets", offset)

    def forward(self, inputs: torch.Tensor):
        """Expand ``inputs``; the result gains a trailing axis of size ``n_rbf``."""
        return gaussian_rbf(inputs, self.offsets, self.widths)
class GaussianRBFCentered(nn.Module):
    r"""Gaussian radial basis functions centered at the origin.

    Each input :math:`x` is expanded into ``n_rbf`` features
    :math:`\exp\left(-\frac{x^2}{2\sigma_k^2}\right)`. All features share the
    center 0, and their widths :math:`\sigma_k` are spaced evenly on
    ``[start, cutoff]``.
    """

    def __init__(
        self, n_rbf: int, cutoff: float, start: float = 1.0, trainable: bool = False
    ):
        r"""
        Args:
            n_rbf: total number of Gaussian functions, :math:`N_g`.
            cutoff: width of last Gaussian function, :math:`\sigma_{N_g}`.
            start: width of first Gaussian function, :math:`\sigma_0`. A width of
                zero makes the basis NaN for zero inputs, so ``[start, cutoff]``
                should not contain 0.
            trainable: If True, the widths of the Gaussian functions are
                registered as an ``nn.Parameter`` and adjusted during training.
                The centers always stay fixed at the origin.
        """
        super().__init__()
        self.n_rbf = n_rbf

        # compute widths of Gaussian functions; centers are all zero
        widths = torch.linspace(start, cutoff, n_rbf)
        if trainable:
            self.widths = nn.Parameter(widths)
        else:
            self.register_buffer("widths", widths)
        # Offsets are always a buffer. Making them a Parameter would let the
        # optimizer move the centers away from the origin. The state_dict key
        # "offsets" is unchanged, so existing checkpoints still load.
        self.register_buffer("offsets", torch.zeros_like(widths))

    def forward(self, inputs: torch.Tensor):
        """Expand ``inputs``; the result gains a trailing axis of size ``n_rbf``."""
        return gaussian_rbf(inputs, self.offsets, self.widths)
class BesselRBF(nn.Module):
    r"""Sine radial basis functions with Coulomb-like :math:`1/r` decay.

    Computes :math:`\sin(f_n r) / r` with frequencies
    :math:`f_n = n \pi / r_{cut}` for :math:`n = 1, \dots, N`. This is
    :math:`f_n` times the 0th-order spherical Bessel function
    :math:`\sin(x)/x`.

    References:

    .. [#dimenet] Klicpera, Groß, Günnemann:
       Directional message passing for molecular graphs.
       ICLR 2020
    """

    def __init__(self, n_rbf: int, cutoff: float):
        """
        Args:
            n_rbf: number of basis functions.
            cutoff: radial cutoff; sets the frequency spacing ``pi / cutoff``.
        """
        super().__init__()
        self.n_rbf = n_rbf

        freqs = torch.arange(1, n_rbf + 1) * pi / cutoff
        self.register_buffer("freqs", freqs)

    def forward(self, inputs: torch.Tensor):
        """Expand ``inputs``; the result gains a trailing axis of size ``n_rbf``.

        Entries where ``inputs == 0`` evaluate to 0, not the analytic limit
        ``freqs``. The division guard below produces ``sin(0) / 1``.
        """
        ax = inputs[..., None] * self.freqs
        sinax = torch.sin(ax)
        # Replace zero distances by 1 to avoid 0/0 = NaN (also in gradients).
        # ones_like matches the dtype and device of ``inputs`` instead of
        # allocating a default-dtype scalar on every call.
        norm = torch.where(inputs == 0, torch.ones_like(inputs), inputs)
        return sinax / norm[..., None]