import math
import torch
import torch.nn as nn
import schnetpack.nn as snn
from .ops.so3 import generate_clebsch_gordan_rsh, sparsify_clebsch_gordon, sh_indices
from .ops.math import binom
from schnetpack.utils import as_dtype
__all__ = [
"RealSphericalHarmonics",
"SO3TensorProduct",
"SO3Convolution",
"SO3GatedNonlinearity",
"SO3ParametricGatedNonlinearity",
]
[docs]class RealSphericalHarmonics(nn.Module):
"""
Generates the real spherical harmonics for a batch of vectors.
Note:
The vectors passed to this layer are assumed to be normalized to unit length.
Spherical harmonics are generated up to angular momentum `lmax` in dimension 1,
according to the following order:
- l=0, m=0
- l=1, m=-1
- l=1, m=0
- l=1, m=1
- l=2, m=-2
- l=2, m=-1
- etc.
"""
def __init__(self, lmax: int, dtype_str: str = "float32"):
"""
Args:
lmax: maximum angular momentum
dtype_str: dtype for spherical harmonics coefficients
"""
super().__init__()
self.lmax = lmax
(
powers,
zpow,
cAm,
cBm,
cPi,
) = self._generate_Ylm_coefficients(lmax)
dtype = as_dtype(dtype_str)
self.register_buffer("powers", powers.to(dtype=dtype), False)
self.register_buffer("zpow", zpow.to(dtype=dtype), False)
self.register_buffer("cAm", cAm.to(dtype=dtype), False)
self.register_buffer("cBm", cBm.to(dtype=dtype), False)
self.register_buffer("cPi", cPi.to(dtype=dtype), False)
ls = torch.arange(0, lmax + 1)
nls = 2 * ls + 1
self.lidx = torch.repeat_interleave(ls, nls)
self.midx = torch.cat([torch.arange(-l, l + 1) for l in ls])
self.register_buffer("flidx", self.lidx.to(dtype=dtype), False)
def _generate_Ylm_coefficients(self, lmax: int):
# see: https://en.wikipedia.org/wiki/Spherical_harmonics#Real_forms
# calculate Am/Bm coefficients
m = torch.arange(1, lmax + 1, dtype=torch.float64)[:, None]
p = torch.arange(0, lmax + 1, dtype=torch.float64)[None, :]
mask = p <= m
mCp = binom(m, p)
cAm = mCp * torch.cos(0.5 * math.pi * (m - p))
cBm = mCp * torch.sin(0.5 * math.pi * (m - p))
cAm *= mask
cBm *= mask
powers = torch.stack([torch.broadcast_to(p, cAm.shape), m - p], dim=-1)
powers *= mask[:, :, None]
# calculate Pi coefficients
l = torch.arange(0, lmax + 1, dtype=torch.float64)[:, None, None]
m = torch.arange(0, lmax + 1, dtype=torch.float64)[None, :, None]
k = torch.arange(0, lmax // 2 + 1, dtype=torch.float64)[None, None, :]
cPi = torch.sqrt(torch.exp(torch.lgamma(l - m + 1) - torch.lgamma(l + m + 1)))
cPi = cPi * (-1) ** k * 2 ** (-l) * binom(l, k) * binom(2 * l - 2 * k, l)
cPi *= torch.exp(torch.lgamma(l - 2 * k + 1) - torch.lgamma(l - 2 * k - m + 1))
zpow = l - 2 * k - m
# masking of invalid entries
cPi = torch.nan_to_num(cPi, 100.0)
mask1 = k <= torch.floor((l - m) / 2)
mask2 = l >= m
mask = mask1 * mask2
cPi *= mask
zpow *= mask
return powers, zpow, cAm, cBm, cPi
def forward(self, directions: torch.Tensor) -> torch.Tensor:
"""
Args:
directions: batch of unit-length 3D vectors (Nx3)
Returns:
real spherical harmonics up ton angular momentum `lmax`
"""
target_shape = [
directions.shape[0],
self.powers.shape[0],
self.powers.shape[1],
2,
]
Rs = torch.broadcast_to(directions[:, None, None, :2], target_shape)
pows = torch.broadcast_to(self.powers[None], target_shape)
Rs = torch.where(pows == 0, torch.ones_like(Rs), Rs)
temp = Rs**self.powers
monomials_xy = torch.prod(temp, dim=-1)
Am = torch.sum(monomials_xy * self.cAm[None], 2)
Bm = torch.sum(monomials_xy * self.cBm[None], 2)
ABm = torch.cat(
[
torch.flip(Bm, (1,)),
math.sqrt(0.5) * torch.ones((Am.shape[0], 1), device=directions.device),
Am,
],
dim=1,
)
ABm = ABm[:, self.midx + self.lmax]
target_shape = [
directions.shape[0],
self.zpow.shape[0],
self.zpow.shape[1],
self.zpow.shape[2],
]
z = torch.broadcast_to(directions[:, 2, None, None, None], target_shape)
zpows = torch.broadcast_to(self.zpow[None], target_shape)
z = torch.where(zpows == 0, torch.ones_like(z), z)
zk = z**zpows
Pi = torch.sum(zk * self.cPi, dim=-1) # batch x L x M
Pi_lm = Pi[:, self.lidx, abs(self.midx)]
sphharm = torch.sqrt((2 * self.flidx + 1) / (2 * math.pi)) * Pi_lm * ABm
return sphharm
def scalar2rsh(x: torch.Tensor, lmax: int) -> torch.Tensor:
"""
Expand scalar tensor to spherical harmonics shape with angular momentum up to `lmax`
Args:
x: tensor of shape [N, *]
lmax: maximum angular momentum
Returns:
zero-padded tensor to shape [N, (lmax+1)^2, *]
"""
y = torch.cat(
[
x,
torch.zeros(
(x.shape[0], int((lmax + 1) ** 2 - 1), x.shape[2]),
device=x.device,
dtype=x.dtype,
),
],
dim=1,
)
return y
[docs]class SO3TensorProduct(nn.Module):
r"""
SO3-equivariant Clebsch-Gordon tensor product.
With combined indexing s=(l,m), this can be written as:
.. math::
y_{s,f} = \sum_{s_1,s_2} x_{2,s_2,f} x_{1,s_2,f} C_{s_1,s_2}^{s}
"""
def __init__(self, lmax: int):
super().__init__()
self.lmax = lmax
cg = generate_clebsch_gordan_rsh(lmax).to(torch.float32)
cg, idx_in_1, idx_in_2, idx_out = sparsify_clebsch_gordon(cg)
self.register_buffer("idx_in_1", idx_in_1, persistent=False)
self.register_buffer("idx_in_2", idx_in_2, persistent=False)
self.register_buffer("idx_out", idx_out, persistent=False)
self.register_buffer("clebsch_gordan", cg, persistent=False)
def forward(
self,
x1: torch.Tensor,
x2: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x1: atom-wise SO3 features, shape: [n_atoms, (l_max+1)^2, n_features]
x2: atom-wise SO3 features, shape: [n_atoms, (l_max+1)^2, n_features]
Returns:
y: product of SO3 features
"""
x1 = x1[:, self.idx_in_1, :]
x2 = x2[:, self.idx_in_2, :]
y = x1 * x2 * self.clebsch_gordan[None, :, None]
y = snn.scatter_add(y, self.idx_out, dim_size=int((self.lmax + 1) ** 2), dim=1)
return y
[docs]class SO3Convolution(nn.Module):
r"""
SO3-equivariant convolution using Clebsch-Gordon tensor product.
With combined indexing s=(l,m), this can be written as:
.. math::
y_{i,s,f} = \sum_{j,s_1,s_2} x_{j,s_2,f} W_{s_1,f}(r_{ij}) Y_{s_1}(r_{ij}) C_{s_1,s_2}^{s}
"""
def __init__(self, lmax: int, n_atom_basis: int, n_radial: int):
super().__init__()
self.lmax = lmax
self.n_atom_basis = n_atom_basis
self.n_radial = n_radial
cg = generate_clebsch_gordan_rsh(lmax).to(torch.float32)
cg, idx_in_1, idx_in_2, idx_out = sparsify_clebsch_gordon(cg)
self.register_buffer("idx_in_1", idx_in_1, persistent=False)
self.register_buffer("idx_in_2", idx_in_2, persistent=False)
self.register_buffer("idx_out", idx_out, persistent=False)
self.register_buffer("clebsch_gordan", cg, persistent=False)
self.filternet = snn.Dense(
n_radial, n_atom_basis * (self.lmax + 1), activation=None
)
lidx, _ = sh_indices(lmax)
self.register_buffer("Widx", lidx[self.idx_in_1])
def _compute_radial_filter(
self, radial_ij: torch.Tensor, cutoff_ij: torch.Tensor
) -> torch.Tensor:
"""
Compute radial (rotationally-invariant) filter
Args:
radial_ij: radial basis functions with shape [n_neighbors, n_radial_basis]
cutoff_ij: cutoff function with shape [n_neighbors, 1]
Returns:
Wij: radial filters with shape [n_neighbors, n_clebsch_gordon, n_features]
"""
Wij = self.filternet(radial_ij) * cutoff_ij
Wij = torch.reshape(Wij, (-1, self.lmax + 1, self.n_atom_basis))
Wij = Wij[:, self.Widx]
return Wij
def forward(
self,
x: torch.Tensor,
radial_ij: torch.Tensor,
dir_ij: torch.Tensor,
cutoff_ij: torch.Tensor,
idx_i: torch.Tensor,
idx_j: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x: atom-wise SO3 features, shape: [n_atoms, (l_max+1)^2, n_atom_basis]
radial_ij: radial basis functions with shape [n_neighbors, n_radial_basis]
dir_ij: direction from atom i to atom j, scaled to unit length
[n_neighbors, 3]
cutoff_ij: cutoff function with shape [n_neighbors, 1]
idx_i: indices for atom i
idx_j: indices for atom j
Returns:
y: convolved SO3 features
"""
xj = x[idx_j[:, None], self.idx_in_2[None, :], :]
Wij = self._compute_radial_filter(radial_ij, cutoff_ij)
v = (
Wij
* dir_ij[:, self.idx_in_1, None]
* self.clebsch_gordan[None, :, None]
* xj
)
yij = snn.scatter_add(
v, self.idx_out, dim_size=int((self.lmax + 1) ** 2), dim=1
)
y = snn.scatter_add(yij, idx_i, dim_size=x.shape[0])
return y
[docs]class SO3ParametricGatedNonlinearity(nn.Module):
r"""
SO3-equivariant parametric gated nonlinearity.
With combined indexing s=(l,m), this can be written as:
.. math::
y_{i,s,f} = x_{j,s,f} * \sigma(f(x_{j,0,\cdot}))
"""
def __init__(self, n_in: int, lmax: int):
super().__init__()
self.lmax = lmax
self.n_in = n_in
self.lidx, _ = sh_indices(lmax)
self.scaling = nn.Linear(n_in, n_in * (lmax + 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
s0 = x[:, 0, :]
h = self.scaling(s0).reshape(-1, self.lmax + 1, self.n_in)
h = h[:, self.lidx]
y = x * torch.sigmoid(h)
return y
[docs]class SO3GatedNonlinearity(nn.Module):
r"""
SO3-equivariant gated nonlinearity.
With combined indexing s=(l,m), this can be written as:
.. math::
y_{i,s,f} = x_{j,s,f} * \sigma(x_{j,0,\cdot})
"""
def __init__(self, lmax: int):
super().__init__()
self.lmax = lmax
self.lidx, _ = sh_indices(lmax)
def forward(self, x: torch.Tensor) -> torch.Tensor:
s0 = x[:, 0, :]
y = x * torch.sigmoid(s0[:, None, :])
return y