from math import pi
import torch
import torch.nn as nn
__all__ = ["gaussian_rbf", "GaussianRBF", "GaussianRBFCentered", "BesselRBF"]
from torch import nn as nn
def gaussian_rbf(inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor):
    r"""Expand ``inputs`` in Gaussians :math:`\exp(-(x - \mu)^2 / (2\sigma^2))`.

    Args:
        inputs: values to expand; a new trailing axis is appended to them.
        offsets: Gaussian centers :math:`\mu`, broadcast against that new axis.
        widths: Gaussian widths :math:`\sigma`, broadcast like ``offsets``.

    Returns:
        The Gaussian values. For 1-D ``offsets`` of length K the result has
        shape ``inputs.shape + (K,)``.
    """
    # Prefactor -1 / (2 sigma^2), one entry per Gaussian.
    scale = -0.5 / widths.pow(2)
    # Squared distance of every input to every center.
    sq_dist = (inputs.unsqueeze(-1) - offsets).pow(2)
    return torch.exp(scale * sq_dist)
class GaussianRBF(nn.Module):
    r"""Gaussian radial basis functions with evenly spaced centers.

    The centers :math:`\mu_k` are placed uniformly on ``[start, cutoff]``.
    Every Gaussian shares one width, equal to the spacing between two
    neighbouring centers.
    """

    def __init__(
        self, n_rbf: int, cutoff: float, start: float = 0.0, trainable: bool = False
    ):
        r"""
        Args:
            n_rbf: total number of Gaussian functions, :math:`N_g`. Must be at
                least 2, because the shared width is derived from the spacing
                of the first two centers.
            cutoff: center of last Gaussian function, :math:`\mu_{N_g}`.
            start: center of first Gaussian function, :math:`\mu_0`.
            trainable: If True, widths and offsets of Gaussian functions
                are adjusted during training process.

        Raises:
            ValueError: if ``n_rbf < 2``.
        """
        super().__init__()
        if n_rbf < 2:
            # Without this check, offset[1] below would fail with an opaque IndexError.
            raise ValueError(
                f"GaussianRBF needs n_rbf >= 2 to derive a width, got {n_rbf}"
            )
        self.n_rbf = n_rbf

        # Centers evenly spaced on [start, cutoff].
        offset = torch.linspace(start, cutoff, n_rbf)
        # One shared width equal to the center spacing. It is built from
        # `offset` so it keeps offset's dtype. The former torch.FloatTensor(...)
        # forced float32, which mismatched offsets when the default dtype is float64.
        widths = (offset[1] - offset[0]).abs() * torch.ones_like(offset)
        if trainable:
            self.widths = nn.Parameter(widths)
            self.offsets = nn.Parameter(offset)
        else:
            # Buffers follow .to()/.cuda() and are saved in state_dict, but are not trained.
            self.register_buffer("widths", widths)
            self.register_buffer("offsets", offset)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Expand ``inputs`` in the basis; a trailing axis of size ``n_rbf`` is appended."""
        return gaussian_rbf(inputs, self.offsets, self.widths)
class GaussianRBFCentered(nn.Module):
    r"""Gaussian radial basis functions centered at the origin.

    All Gaussians start with center 0. Their widths :math:`\sigma_k` are
    spaced uniformly on ``[start, cutoff]``.
    """

    def __init__(
        self, n_rbf: int, cutoff: float, start: float = 1.0, trainable: bool = False
    ):
        r"""
        Args:
            n_rbf: total number of Gaussian functions, :math:`N_g`.
            cutoff: width of last Gaussian function, :math:`\sigma_{N_g}`.
            start: width of first Gaussian function, :math:`\sigma_0`.
            trainable: If True, widths of Gaussian functions are adjusted
                during training process. Note that the (initially zero)
                offsets are registered as trainable parameters as well, so
                the Gaussians may move away from the origin during training.
        """
        super().__init__()
        self.n_rbf = n_rbf

        # Widths evenly spaced on [start, cutoff]; every center starts at 0.
        widths = torch.linspace(start, cutoff, n_rbf)
        offset = torch.zeros_like(widths)
        if trainable:
            # NOTE(review): offsets are trainable too (existing behavior, kept
            # for checkpoint/optimizer compatibility); confirm this is intended.
            self.widths = nn.Parameter(widths)
            self.offsets = nn.Parameter(offset)
        else:
            self.register_buffer("widths", widths)
            self.register_buffer("offsets", offset)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Expand ``inputs`` in the basis; a trailing axis of size ``n_rbf`` is appended."""
        return gaussian_rbf(inputs, self.offsets, self.widths)
class BesselRBF(nn.Module):
    r"""
    Sine radial basis functions with Coulomb-like :math:`1/r` decay.

    Basis function :math:`n` is :math:`\sin(k_n r) / r` with frequencies
    :math:`k_n = n \pi / r_{cut}`, :math:`n = 1, \dots, N`. This is the
    0th-order spherical Bessel function :math:`j_0(x) = \sin(x) / x` at
    :math:`x = k_n r`, scaled by :math:`k_n`. Every basis function vanishes
    at :math:`r = r_{cut}`.

    References:

    .. [#dimenet] Klicpera, Groß, Günnemann:
       Directional message passing for molecular graphs.
       ICLR 2020
    """

    def __init__(self, n_rbf: int, cutoff: float):
        r"""
        Args:
            n_rbf: number of basis functions, :math:`N`.
            cutoff: radial cutoff :math:`r_{cut}`.
        """
        super().__init__()
        self.n_rbf = n_rbf
        # k_n = n * pi / cutoff for n = 1..n_rbf (default float dtype).
        freqs = torch.arange(1, n_rbf + 1) * pi / cutoff
        self.register_buffer("freqs", freqs)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Expand ``inputs`` in the basis; a trailing axis of size ``n_rbf`` is appended.

        Entries with ``inputs == 0`` evaluate to 0. Those positions are divided
        by 1 instead of 0, which gives sin(0) / 1 = 0 rather than the
        analytic limit k_n.
        """
        ax = inputs[..., None] * self.freqs
        sinax = torch.sin(ax)
        # Guard against 0/0. ones_like keeps the inputs' dtype and device and
        # avoids allocating a new float32 scalar tensor on every call.
        norm = torch.where(inputs == 0, torch.ones_like(inputs), inputs)
        return sinax / norm[..., None]