Source code for representation.schnet

from typing import Callable, Dict, Union

import torch
from torch import nn

import schnetpack.properties as structure
from schnetpack.nn import Dense, scatter_add
from schnetpack.nn.embedding import NuclearEmbedding
from schnetpack.nn import ElectronicEmbedding
from schnetpack.nn.activations import shifted_softplus

import schnetpack.nn as snn

__all__ = ["SchNet", "SchNetInteraction"]


class SchNetInteraction(nn.Module):
    r"""SchNet interaction block for modeling interactions of atomistic systems."""

    def __init__(
        self,
        n_atom_basis: int,
        n_rbf: int,
        n_filters: int,
        activation: Callable = shifted_softplus,
    ):
        """
        Args:
            n_atom_basis: number of features to describe atomic environments.
            n_rbf (int): number of radial basis functions.
            n_filters: number of filters used in continuous-filter convolution.
            activation: if None, no activation function is used.
        """
        super(SchNetInteraction, self).__init__()
        self.in2f = Dense(n_atom_basis, n_filters, bias=False, activation=None)
        self.f2out = nn.Sequential(
            Dense(n_filters, n_atom_basis, activation=activation),
            Dense(n_atom_basis, n_atom_basis, activation=None),
        )
        self.filter_network = nn.Sequential(
            Dense(n_rbf, n_filters, activation=activation), Dense(n_filters, n_filters)
        )

    def forward(
        self,
        x: torch.Tensor,
        f_ij: torch.Tensor,
        idx_i: torch.Tensor,
        idx_j: torch.Tensor,
        rcut_ij: torch.Tensor,
    ):
        """Compute interaction output.

        Args:
            x: input values
            Wij: filter
            idx_i: index of center atom i
            idx_j: index of neighbors j

        Returns:
            atom features after interaction
        """
        x = self.in2f(x)
        Wij = self.filter_network(f_ij)
        Wij = Wij * rcut_ij[:, None]

        # continuous-filter convolution
        x_j = x[idx_j]
        x_ij = x_j * Wij
        x = scatter_add(x_ij, idx_i, dim_size=x.shape[0])

        x = self.f2out(x)
        return x



[docs]class SchNet(nn.Module): """SchNet architecture for learning representations of atomistic systems References: .. [#schnet1] Schütt, Arbabzadah, Chmiela, Müller, Tkatchenko: Quantum-chemical insights from deep tensor neural networks. Nature Communications, 8, 13890. 2017. .. [#schnet_transfer] Schütt, Kindermans, Sauceda, Chmiela, Tkatchenko, Müller: SchNet: A continuous-filter convolutional neural network for modeling quantum interactions. In Advances in Neural Information Processing Systems, pp. 992-1002. 2017. .. [#schnet3] Schütt, Sauceda, Kindermans, Tkatchenko, Müller: SchNet - a deep learning architecture for molceules and materials. The Journal of Chemical Physics 148 (24), 241722. 2018. """ def __init__( self, n_atom_basis: int, n_interactions: int, radial_basis: nn.Module, cutoff_fn: Callable, n_filters: int = None, shared_interactions: bool = False, max_z: int = 101, activation: Union[Callable, nn.Module] = shifted_softplus, activate_charge_spin_embedding: bool = False, embedding: Union[Callable, nn.Module] = None, ): """ Args: n_atom_basis: number of features to describe atomic environments. This determines the size of each embedding vector; i.e. embeddings_dim. n_interactions: number of interaction blocks. radial_basis: layer for expanding interatomic distances in a basis set cutoff_fn: cutoff function n_filters: number of filters used in continuous-filter convolution shared_interactions: if True, share the weights across interaction blocks and filter-generating networks. max_z: maximal nuclear charge activation: activation function activate_charge_spin_embedding: if True, charge and spin embeddings are added to nuclear embeddings taken from SpookyNet Implementation embedding: type of nuclear embedding to use (simple is simple embedding and complex is the one with electron configuration) """ super().__init__() self.n_atom_basis = n_atom_basis self.n_filters = n_filters or self.n_atom_basis self.radial_basis = radial_basis self.cutoff_fn = cutoff_fn self.cutoff = cutoff_fn.cutoff self.activate_charge_spin_embedding = activate_charge_spin_embedding # initialize nuclear embedding self.embedding = embedding if self.embedding is None: self.embedding = nn.Embedding(max_z, self.n_atom_basis, padding_idx=0) # initialize spin and charge embeddings if self.activate_charge_spin_embedding: self.charge_embedding = ElectronicEmbedding( self.n_atom_basis, num_residual=1, activation=activation, is_charged=True) self.spin_embedding = ElectronicEmbedding( self.n_atom_basis, num_residual=1, activation=activation, is_charged=False) # initialize interaction blocks self.interactions = snn.replicate_module( lambda: SchNetInteraction( n_atom_basis=self.n_atom_basis, n_rbf=self.radial_basis.n_rbf, n_filters=self.n_filters, activation=activation, ), n_interactions, shared_interactions, ) def forward(self, inputs: Dict[str, torch.Tensor]): # get tensors from input dictionary atomic_numbers = inputs[structure.Z] r_ij = inputs[structure.Rij] idx_i = inputs[structure.idx_i] idx_j = inputs[structure.idx_j] # compute pair features d_ij = torch.norm(r_ij, dim=1) f_ij = self.radial_basis(d_ij) rcut_ij = self.cutoff_fn(d_ij) # compute initial embeddings x = self.embedding(atomic_numbers) # add spin and charge embeddings if hasattr(self, "activate_charge_spin_embedding") and self.activate_charge_spin_embedding: # get tensors from input dictionary total_charge = inputs[structure.total_charge] spin = inputs[structure.spin_multiplicity] idx_m = inputs[structure.idx_m] num_batch = len(inputs[structure.idx]) charge_embedding = self.charge_embedding( x, total_charge, num_batch, idx_m ) spin_embedding = self.spin_embedding( x, spin, num_batch, idx_m ) # additive combining of nuclear, charge and spin embedding x = x + charge_embedding + spin_embedding # compute interaction blocks and update atomic embeddings for interaction in self.interactions: v = interaction(x, f_ij, idx_i, idx_j, rcut_ij) x = x + v # collect results inputs["scalar_representation"] = x return inputs