# Source code for representation.painn

from typing import Callable, Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

import schnetpack.properties as properties
import schnetpack.nn as snn
from schnetpack.nn.embedding import NuclearEmbedding, ElectronicEmbedding

__all__ = ["PaiNN", "PaiNNInteraction", "PaiNNMixing"]


class PaiNNInteraction(nn.Module):
    r"""Inter-atomic PaiNN block: equivariant message passing between neighbors.

    Scalar features of the neighbor atoms are passed through a small network,
    modulated by the distance-dependent filter and used both as scalar
    messages and as gates for the vector messages.
    """

    def __init__(self, n_atom_basis: int, activation: Callable):
        """
        Args:
            n_atom_basis: number of features to describe atomic environments.
            activation: activation of the first dense layer; if None, no
                activation function is used.
        """
        super().__init__()
        self.n_atom_basis = n_atom_basis

        # Two dense layers; the output is three chunks of width n_atom_basis
        # (scalar message, gate for dir_ij, gate for neighbor vectors).
        hidden_layer = snn.Dense(n_atom_basis, n_atom_basis, activation=activation)
        output_layer = snn.Dense(n_atom_basis, 3 * n_atom_basis, activation=None)
        self.interatomic_context_net = nn.Sequential(hidden_layer, output_layer)

    def forward(
        self,
        q: torch.Tensor,
        mu: torch.Tensor,
        Wij: torch.Tensor,
        dir_ij: torch.Tensor,
        idx_i: torch.Tensor,
        idx_j: torch.Tensor,
        n_atoms: int,
    ):
        """Compute interaction output.

        Args:
            q: scalar input values
            mu: vector input values
            Wij: filter, multiplied element-wise with the gathered neighbor
                context (last axis split into three chunks of n_atom_basis)
            dir_ij: direction associated with each pair; a trailing axis is
                appended so it broadcasts against the feature axis
            idx_i: index of center atom i
            idx_j: index of neighbors j
            n_atoms: number of atoms, used as output size of the scatter sums

        Returns:
            atom features after interaction (updated q, updated mu)
        """
        # Per-atom context, gathered for every neighbor j and filtered.
        context = self.interatomic_context_net(q)
        messages = Wij * context[idx_j]
        neighbor_mu = mu[idx_j]

        scalar_msg, radial_gate, vector_gate = torch.split(
            messages, self.n_atom_basis, dim=-1
        )

        # Vector message: gated pair direction plus gated neighbor vectors.
        vector_msg = radial_gate * dir_ij[..., None] + vector_gate * neighbor_mu

        # Sum the pair messages onto their center atoms.
        q_update = snn.scatter_add(scalar_msg, idx_i, dim_size=n_atoms)
        mu_update = snn.scatter_add(vector_msg, idx_i, dim_size=n_atoms)

        return q + q_update, mu + mu_update


class PaiNNMixing(nn.Module):
    r"""PaiNN interaction block for mixing on atom features.

    Mixes the feature channels of the vector representation and couples
    scalar and vector features of each atom (intra-atomic update).
    """

    def __init__(
        self, n_atom_basis: int, activation: Optional[Callable], epsilon: float = 1e-8
    ):
        """
        Args:
            n_atom_basis: number of features to describe atomic environments.
            activation: if None, no activation function is used.
            epsilon: stability constant added in norm to prevent numerical instabilities
        """
        super().__init__()
        self.n_atom_basis = n_atom_basis

        # Input is [q, |mu_V|]; output is three chunks of width n_atom_basis.
        self.intraatomic_context_net = nn.Sequential(
            snn.Dense(2 * n_atom_basis, n_atom_basis, activation=activation),
            snn.Dense(n_atom_basis, 3 * n_atom_basis, activation=None),
        )
        # Bias-free linear map over the feature axis producing mu_V and mu_W.
        self.mu_channel_mix = snn.Dense(
            n_atom_basis, 2 * n_atom_basis, activation=None, bias=False
        )
        self.epsilon = epsilon

    def forward(self, q: torch.Tensor, mu: torch.Tensor):
        """Compute intraatomic mixing.

        Both reductions run over axis ``-2`` (the axis holding the vector
        components), so inputs with additional leading dimensions are handled
        the same way as the plain ``(n_atoms, ., n_atom_basis)`` case.

        Args:
            q: scalar input values, size 1 on axis -2 and n_atom_basis on
                axis -1 (as produced by PaiNN.forward)
            mu: vector input values, vector components on axis -2 and
                n_atom_basis on axis -1

        Returns:
            atom features after interaction (updated q, updated mu)
        """
        # Mix feature channels and split into the two vector sets V and W.
        mu_mix = self.mu_channel_mix(mu)
        mu_V, mu_W = torch.split(mu_mix, self.n_atom_basis, dim=-1)

        # Norm of V over the component axis; epsilon keeps sqrt
        # differentiable when the vectors are zero.
        mu_Vn = torch.sqrt(torch.sum(mu_V**2, dim=-2, keepdim=True) + self.epsilon)

        ctx = torch.cat([q, mu_Vn], dim=-1)
        x = self.intraatomic_context_net(ctx)

        dq_intra, dmu_intra, dqmu_intra = torch.split(x, self.n_atom_basis, dim=-1)
        dmu_intra = dmu_intra * mu_W

        # Scalar product <V, W> over the same component axis as the norm above
        # (previously hard-coded to dim=1, which only matched for 3-D inputs).
        dqmu_intra = dqmu_intra * torch.sum(mu_V * mu_W, dim=-2, keepdim=True)

        q = q + dq_intra + dqmu_intra
        mu = mu + dmu_intra
        return q, mu


class PaiNN(nn.Module):
    """PaiNN - polarizable interaction neural network

    References:

    .. [#painn1] Schütt, Unke, Gastegger:
       Equivariant message passing for the prediction of tensorial properties
       and molecular spectra.
       ICML 2021, http://proceedings.mlr.press/v139/schutt21a.html

    """

    def __init__(
        self,
        n_atom_basis: int,
        n_interactions: int,
        radial_basis: nn.Module,
        cutoff_fn: Optional[Callable] = None,
        activation: Optional[Callable] = F.silu,
        max_z: int = 101,
        shared_interactions: bool = False,
        shared_filters: bool = False,
        epsilon: float = 1e-8,
        activate_charge_spin_embedding: bool = False,
        embedding: Optional[Union[Callable, nn.Module]] = None,
    ):
        """
        Args:
            n_atom_basis: number of features to describe atomic environments.
                This determines the size of each embedding vector; i.e.
                embeddings_dim.
            n_interactions: number of interaction blocks.
            radial_basis: layer for expanding interatomic distances in a basis
                set; must expose ``n_rbf``.
            cutoff_fn: cutoff function; must expose ``cutoff``. If None, the
                filters are used without a cutoff and ``self.cutoff`` is None.
            activation: activation function
            max_z: maximal nuclear charge
            shared_interactions: if True, share the weights across interaction
                blocks.
            shared_filters: if True, share the weights across filter-generating
                networks.
            epsilon: stability constant added in norm to prevent numerical
                instabilities
            activate_charge_spin_embedding: if True, charge and spin embeddings
                are added to nuclear embeddings taken from SpookyNet
                Implementation
            embedding: custom nuclear embedding; if None, a ``nn.Embedding``
                with ``max_z`` entries is created.
        """
        super().__init__()

        self.n_atom_basis = n_atom_basis
        self.n_interactions = n_interactions
        self.cutoff_fn = cutoff_fn
        # cutoff_fn defaults to None, so its attribute must not be read
        # unconditionally.
        self.cutoff = None if cutoff_fn is None else cutoff_fn.cutoff
        self.radial_basis = radial_basis
        self.activate_charge_spin_embedding = activate_charge_spin_embedding

        # initialize nuclear embedding
        self.embedding = embedding
        if self.embedding is None:
            self.embedding = nn.Embedding(max_z, self.n_atom_basis, padding_idx=0)

        # initialize spin and charge embeddings
        if self.activate_charge_spin_embedding:
            self.charge_embedding = ElectronicEmbedding(
                self.n_atom_basis,
                num_residual=1,
                activation=activation,
                is_charged=True,
            )
            self.spin_embedding = ElectronicEmbedding(
                self.n_atom_basis,
                num_residual=1,
                activation=activation,
                is_charged=False,
            )

        # initialize filter layers: one shared set of 3 * n_atom_basis filters,
        # or one such set per interaction block packed along the last axis
        self.share_filters = shared_filters
        if shared_filters:
            self.filter_net = snn.Dense(
                self.radial_basis.n_rbf, 3 * n_atom_basis, activation=None
            )
        else:
            self.filter_net = snn.Dense(
                self.radial_basis.n_rbf,
                self.n_interactions * n_atom_basis * 3,
                activation=None,
            )

        # initialize interaction blocks
        self.interactions = snn.replicate_module(
            lambda: PaiNNInteraction(
                n_atom_basis=self.n_atom_basis, activation=activation
            ),
            self.n_interactions,
            shared_interactions,
        )
        self.mixing = snn.replicate_module(
            lambda: PaiNNMixing(
                n_atom_basis=self.n_atom_basis, activation=activation, epsilon=epsilon
            ),
            self.n_interactions,
            shared_interactions,
        )

    def forward(self, inputs: Dict[str, torch.Tensor]):
        """
        Compute atomic representations/embeddings.

        Args:
            inputs: SchNetPack dictionary of input tensors. Reads
                ``properties.Z``, ``properties.Rij``, ``properties.idx_i`` and
                ``properties.idx_j``; with charge/spin embedding enabled also
                ``properties.total_charge``, ``properties.spin_multiplicity``,
                ``properties.idx`` and ``properties.idx_m``.

        Returns:
            The same ``inputs`` dictionary with two entries added:
            ``"scalar_representation"`` (atom-wise scalar features, the
            singleton axis 1 removed) and ``"vector_representation"``
            (atom-wise vector features with 3 components on axis 1).
        """
        # get tensors from input dictionary
        atomic_numbers = inputs[properties.Z]
        r_ij = inputs[properties.Rij]
        idx_i = inputs[properties.idx_i]
        idx_j = inputs[properties.idx_j]
        n_atoms = atomic_numbers.shape[0]

        # compute atom and pair features
        d_ij = torch.norm(r_ij, dim=1, keepdim=True)
        dir_ij = r_ij / d_ij
        phi_ij = self.radial_basis(d_ij)
        filters = self.filter_net(phi_ij)
        if self.cutoff_fn is not None:
            fcut = self.cutoff_fn(d_ij)
            filters = filters * fcut[..., None]

        if self.share_filters:
            filter_list = [filters] * self.n_interactions
        else:
            filter_list = torch.split(filters, 3 * self.n_atom_basis, dim=-1)

        # compute initial embeddings; axis 1 is a singleton that broadcasts
        # against the 3 vector components
        q = self.embedding(atomic_numbers)[:, None]

        # add spin and charge embeddings
        # (hasattr guard kept: presumably for instances created before the
        # flag existed -- TODO confirm)
        if (
            hasattr(self, "activate_charge_spin_embedding")
            and self.activate_charge_spin_embedding
        ):
            # get tensors from input dictionary
            total_charge = inputs[properties.total_charge]
            spin = inputs[properties.spin_multiplicity]
            num_batch = len(inputs[properties.idx])
            idx_m = inputs[properties.idx_m]

            # squeeze only the singleton axis 1: a bare squeeze() would also
            # drop the atom axis when there is exactly one atom
            q_flat = q.squeeze(1)
            charge_embedding = self.charge_embedding(
                q_flat, total_charge, num_batch, idx_m
            )[:, None]
            spin_embedding = self.spin_embedding(q_flat, spin, num_batch, idx_m)[
                :, None
            ]

            # additive combining of nuclear, charge and spin embedding
            q = q + charge_embedding + spin_embedding

        # compute interaction blocks and update atomic embeddings;
        # mu starts at zero and follows q's dtype and device so the model
        # also runs in non-default precision
        qs = q.shape
        mu = torch.zeros((qs[0], 3, qs[2]), dtype=q.dtype, device=q.device)
        for i, (interaction, mixing) in enumerate(zip(self.interactions, self.mixing)):
            q, mu = interaction(q, mu, filter_list[i], dir_ij, idx_i, idx_j, n_atoms)
            q, mu = mixing(q, mu)
        q = q.squeeze(1)

        # collect results
        inputs["scalar_representation"] = q
        inputs["vector_representation"] = mu

        return inputs