import torch
import torch.nn as nn
import torch.nn.functional as F
import schnetpack.nn as snn
from typing import Tuple
__all__ = ["GatedEquivariantBlock"]
class GatedEquivariantBlock(nn.Module):
    """
    Gated equivariant block as used for the prediction of tensorial properties by PaiNN.

    Transforms scalar and vector representation using gated nonlinearities:

    1. The vector features are mixed by a bias-free linear map into two sets
       of ``n_vout`` channels each, called ``V`` and ``W`` below.
    2. The norm of ``V`` (taken over the second-to-last axis) is concatenated
       to the scalar features and passed through a two-layer network.
    3. The network output is split into ``n_sout`` scalar outputs and
       ``n_vout`` gates; the gates scale ``W`` channel-wise to give the
       vector outputs.

    References:

    .. [#painn1] Schütt, Unke, Gastegger:
       Equivariant message passing for the prediction of tensorial properties and molecular spectra.
       ICML 2021 (to appear)
    """

    def __init__(
        self,
        n_sin: int,
        n_vin: int,
        n_sout: int,
        n_vout: int,
        n_hidden: int,
        activation=F.silu,
        sactivation=None,
    ):
        """
        Args:
            n_sin: number of input scalar features
            n_vin: number of input vector features
            n_sout: number of output scalar features
            n_vout: number of output vector features
            n_hidden: number of hidden units
            activation: internal activation function
            sactivation: activation function for scalar outputs
                (``None`` leaves the scalar outputs unactivated)
        """
        super().__init__()
        self.n_sin = n_sin
        self.n_vin = n_vin
        self.n_sout = n_sout
        self.n_vout = n_vout
        self.n_hidden = n_hidden

        # One bias-free projection produces both vector sets at once
        # (2 * n_vout channels); forward() splits it into V and W.
        self.mix_vectors = snn.Dense(n_vin, 2 * n_vout, activation=None, bias=False)

        # Input: scalars concatenated with the n_vout norms of V.
        # Output: n_sout scalar outputs followed by n_vout gates for W.
        self.scalar_net = nn.Sequential(
            snn.Dense(n_sin + n_vout, n_hidden, activation=activation),
            snn.Dense(n_hidden, n_sout + n_vout, activation=None),
        )
        self.sactivation = sactivation

    def forward(
        self, inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the gated equivariant transformation.

        Args:
            inputs: tuple ``(scalars, vectors)``. As used below, ``scalars``
                needs a trailing axis of size ``n_sin`` and ``vectors`` a
                trailing axis of size ``n_vin`` preceded by the axis the norm
                is taken over (presumably the 3 Cartesian components, i.e.
                ``(..., 3, n_vin)`` -- confirm against the caller).

        Returns:
            Tuple ``(s_out, v_out)``: scalar outputs with trailing size
            ``n_sout`` and vector outputs shaped like ``vectors`` but with
            trailing size ``n_vout``.
        """
        scalars, vectors = inputs

        # Mix the vector channels, then split into the norm branch (V)
        # and the gated branch (W).
        vmix = self.mix_vectors(vectors)
        vectors_V, vectors_W = torch.split(vmix, self.n_vout, dim=-1)

        # Reducing the second-to-last axis leaves one norm per V channel.
        vectors_Vn = torch.norm(vectors_V, dim=-2)

        ctx = torch.cat([scalars, vectors_Vn], dim=-1)
        x = self.scalar_net(ctx)
        s_out, gates = torch.split(x, [self.n_sout, self.n_vout], dim=-1)

        # Re-insert the reduced axis so each gate broadcasts over it.
        v_out = gates.unsqueeze(-2) * vectors_W

        # Explicit None check: a truthiness test would silently skip a
        # callable that happens to be falsy (e.g. an empty nn.Sequential).
        if self.sactivation is not None:
            s_out = self.sactivation(s_out)
        return s_out, v_out