# Source code for schnetpack.nn.equivariant

import torch
import torch.nn as nn
import torch.nn.functional as F

import schnetpack.nn as snn
from typing import Tuple

# Public API of this module (the names exported by ``from ... import *``).
__all__ = ["GatedEquivariantBlock"]


class GatedEquivariantBlock(nn.Module):
    """
    Gated equivariant block as used for the prediction of tensorial properties by PaiNN.
    Transforms scalar and vector representation using gated nonlinearities.

    The vector input is linearly mixed into two sets of channels (V and W). The
    norms of V are concatenated with the scalar input and passed through a small
    scalar network. That network produces the scalar output and one gate per
    output vector channel; the gates rescale W to form the vector output.

    References:

    .. [#painn1] Schütt, Unke, Gastegger:
       Equivariant message passing for the prediction of tensorial properties and molecular spectra.
       ICML 2021
    """

    def __init__(
        self,
        n_sin: int,
        n_vin: int,
        n_sout: int,
        n_vout: int,
        n_hidden: int,
        activation=F.silu,
        sactivation=None,
    ):
        """
        Args:
            n_sin: number of input scalar features
            n_vin: number of input vector features
            n_sout: number of output scalar features
            n_vout: number of output vector features
            n_hidden: number of hidden units
            activation: internal activation function, applied in the hidden
                layer of the scalar network
            sactivation: activation function for scalar outputs; ``None``
                leaves the scalar outputs linear
        """
        super().__init__()
        self.n_sin = n_sin
        self.n_vin = n_vin
        self.n_sout = n_sout
        self.n_vout = n_vout
        self.n_hidden = n_hidden

        # Mix the n_vin vector channels into 2 * n_vout channels: the first half (V)
        # provides invariant norms for the scalar net, the second half (W) is gated
        # to form the vector output. No bias and no activation, so the mixing stays
        # linear in the vectors (a bias would break rotation equivariance).
        self.mix_vectors = snn.Dense(n_vin, 2 * n_vout, activation=None, bias=False)
        # Scalar net: consumes [scalars, |V|] and emits the n_sout scalar outputs
        # followed by n_vout gates (one per output vector channel).
        self.scalar_net = nn.Sequential(
            snn.Dense(n_sin + n_vout, n_hidden, activation=activation),
            snn.Dense(n_hidden, n_sout + n_vout, activation=None),
        )
        self.sactivation = sactivation

    def forward(
        self, inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the gated equivariant transformation.

        Args:
            inputs: pair ``(scalars, vectors)`` with ``scalars`` of shape
                ``(..., n_sin)`` and ``vectors`` of shape ``(..., k, n_vin)``,
                where the leading dimensions ``...`` agree and ``k`` is the
                spatial axis (presumably 3 for Cartesian vectors).

        Returns:
            ``(s_out, v_out)`` with shapes ``(..., n_sout)`` and
            ``(..., k, n_vout)``.
        """
        scalars, vectors = inputs
        vmix = self.mix_vectors(vectors)
        vectors_V, vectors_W = torch.split(vmix, self.n_vout, dim=-1)
        # Norm over the spatial axis turns each V channel into an invariant feature.
        vectors_Vn = torch.norm(vectors_V, dim=-2)

        ctx = torch.cat([scalars, vectors_Vn], dim=-1)
        x = self.scalar_net(ctx)
        s_out, gates = torch.split(x, [self.n_sout, self.n_vout], dim=-1)
        # Scale every W channel by its invariant gate, broadcast over the spatial axis.
        v_out = gates.unsqueeze(-2) * vectors_W

        # Only the scalar output is activated; applying a pointwise nonlinearity
        # to the vectors would break equivariance.
        if self.sactivation is not None:
            s_out = self.sactivation(s_out)

        return s_out, v_out