from typing import Callable, Dict, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import schnetpack.properties as properties
import schnetpack.nn as snn
from schnetpack.nn.embedding import NuclearEmbedding, ElectronicEmbedding
__all__ = ["PaiNN", "PaiNNInteraction", "PaiNNMixing"]
class PaiNNInteraction(nn.Module):
    r"""PaiNN message-passing block: updates scalar and vector atom features from neighbors."""

    def __init__(self, n_atom_basis: int, activation: Callable):
        """
        Args:
            n_atom_basis: number of features to describe atomic environments.
            activation: activation passed to the first dense layer;
                if None, no activation function is used.
        """
        super().__init__()
        self.n_atom_basis = n_atom_basis

        # Two dense layers; the output is three times as wide as the input
        # because forward() splits it into three equally sized parts.
        hidden_layer = snn.Dense(n_atom_basis, n_atom_basis, activation=activation)
        output_layer = snn.Dense(n_atom_basis, 3 * n_atom_basis, activation=None)
        self.interatomic_context_net = nn.Sequential(hidden_layer, output_layer)

    def forward(
        self,
        q: torch.Tensor,
        mu: torch.Tensor,
        Wij: torch.Tensor,
        dir_ij: torch.Tensor,
        idx_i: torch.Tensor,
        idx_j: torch.Tensor,
        n_atoms: int,
    ):
        """Compute interaction output.

        Args:
            q: scalar input values
            mu: vector input values
            Wij: filter, multiplied element-wise with the neighbor context
            dir_ij: per-pair direction factors; a trailing axis is appended
                before they are multiplied with the vector messages
            idx_i: index of center atom i
            idx_j: index of neighbors j
            n_atoms: number of atoms, used as ``dim_size`` of the neighbor sums

        Returns:
            atom features after interaction, as the tuple ``(q, mu)``
        """
        # inter-atomic: build the per-atom context, then gather it (and the
        # vector features) for every neighbor j
        context = self.interatomic_context_net(q)
        neighbor_mu = mu[idx_j]
        messages = Wij * context[idx_j]

        # three equally sized parts: scalar update, directional vector
        # update and vector-to-vector update
        dq, dmuR, dmumu = messages.split(self.n_atom_basis, dim=-1)

        # accumulate the pairwise contributions on the center atoms i
        dq = snn.scatter_add(dq, idx_i, dim_size=n_atoms)
        dmu = dmuR * dir_ij[..., None] + dmumu * neighbor_mu
        dmu = snn.scatter_add(dmu, idx_i, dim_size=n_atoms)

        # residual updates
        return q + dq, mu + dmu
class PaiNNMixing(nn.Module):
    r"""PaiNN block that mixes scalar and vector features within each atom."""

    def __init__(self, n_atom_basis: int, activation: Callable, epsilon: float = 1e-8):
        """
        Args:
            n_atom_basis: number of features to describe atomic environments.
            activation: activation passed to the first dense layer;
                if None, no activation function is used.
            epsilon: stability constant added in norm to prevent numerical instabilities
        """
        super().__init__()
        self.n_atom_basis = n_atom_basis
        self.epsilon = epsilon

        # Input is the concatenation of q and a vector norm (2 * n_atom_basis);
        # output is split into three parts in forward() (3 * n_atom_basis).
        hidden_layer = snn.Dense(2 * n_atom_basis, n_atom_basis, activation=activation)
        output_layer = snn.Dense(n_atom_basis, 3 * n_atom_basis, activation=None)
        self.intraatomic_context_net = nn.Sequential(hidden_layer, output_layer)

        # Bias-free map of the vector features to two sets of channels.
        self.mu_channel_mix = snn.Dense(
            n_atom_basis, 2 * n_atom_basis, activation=None, bias=False
        )

    def forward(self, q: torch.Tensor, mu: torch.Tensor):
        """Compute intraatomic mixing.

        Args:
            q: scalar input values
            mu: vector input values

        Returns:
            atom features after interaction, as the tuple ``(q, mu)``
        """
        # intra-atomic: two channel mixtures V and W of the vector features
        mu_V, mu_W = self.mu_channel_mix(mu).split(self.n_atom_basis, dim=-1)

        # norm of V over the second-to-last axis; epsilon keeps the sqrt
        # (and its gradient) finite when V vanishes
        squared_norm = (mu_V**2).sum(dim=-2, keepdim=True)
        mu_Vn = torch.sqrt(squared_norm + self.epsilon)

        # scalar context built from q and the norm
        context = self.intraatomic_context_net(torch.cat([q, mu_Vn], dim=-1))
        dq_intra, dmu_intra, dqmu_intra = context.split(self.n_atom_basis, dim=-1)

        # <V, W> summed over axis 1 (the same axis as -2 for 3-D inputs)
        vw_product = (mu_V * mu_W).sum(dim=1, keepdim=True)

        # residual updates
        q = q + dq_intra + dqmu_intra * vw_product
        mu = mu + dmu_intra * mu_W
        return q, mu
class PaiNN(nn.Module):
    """PaiNN - polarizable interaction neural network

    References:

    .. [#painn1] Schütt, Unke, Gastegger:
       Equivariant message passing for the prediction of tensorial properties and molecular spectra.
       ICML 2021, http://proceedings.mlr.press/v139/schutt21a.html

    """

    def __init__(
        self,
        n_atom_basis: int,
        n_interactions: int,
        radial_basis: nn.Module,
        cutoff_fn: Optional[Callable] = None,
        activation: Optional[Callable] = F.silu,
        max_z: int = 101,
        shared_interactions: bool = False,
        shared_filters: bool = False,
        epsilon: float = 1e-8,
        activate_charge_spin_embedding: bool = False,
        embedding: Union[Callable, nn.Module] = None,
    ):
        """
        Args:
            n_atom_basis: number of features to describe atomic environments.
                This determines the size of each embedding vector; i.e. embeddings_dim.
            n_interactions: number of interaction blocks.
            radial_basis: layer for expanding interatomic distances in a basis set
            cutoff_fn: cutoff function. Required despite the ``None`` default:
                its ``cutoff`` attribute is read here and it is called in
                every forward pass.
            max_z: maximal nuclear charge
            activation: activation function
            shared_interactions: if True, share the weights across
                interaction blocks.
            shared_filters: if True, share the weights across
                filter-generating networks.
            epsilon: stability constant added in norm to prevent numerical instabilities
            activate_charge_spin_embedding: if True, charge and spin embeddings are added
                to nuclear embeddings taken from SpookyNet Implementation
            embedding: custom nuclear embedding

        Raises:
            ValueError: if ``cutoff_fn`` is None.
        """
        super().__init__()

        # Fail early with a clear message instead of an AttributeError on
        # `cutoff_fn.cutoff` below.
        if cutoff_fn is None:
            raise ValueError(
                "PaiNN requires a cutoff function: `cutoff_fn` must not be None."
            )

        self.n_atom_basis = n_atom_basis
        self.n_interactions = n_interactions
        self.cutoff_fn = cutoff_fn
        self.cutoff = cutoff_fn.cutoff
        self.radial_basis = radial_basis
        self.activate_charge_spin_embedding = activate_charge_spin_embedding

        # initialize nuclear embedding
        self.embedding = embedding
        if self.embedding is None:
            self.embedding = nn.Embedding(max_z, self.n_atom_basis, padding_idx=0)

        # initialize spin and charge embeddings
        if self.activate_charge_spin_embedding:
            self.charge_embedding = ElectronicEmbedding(
                self.n_atom_basis,
                num_residual=1,
                activation=activation,
                is_charged=True,
            )
            self.spin_embedding = ElectronicEmbedding(
                self.n_atom_basis,
                num_residual=1,
                activation=activation,
                is_charged=False,
            )

        # initialize filter layers: either one filter set reused by every
        # interaction block, or one wide layer whose output is split into a
        # separate filter set per block in forward()
        self.share_filters = shared_filters
        if shared_filters:
            self.filter_net = snn.Dense(
                self.radial_basis.n_rbf, 3 * n_atom_basis, activation=None
            )
        else:
            self.filter_net = snn.Dense(
                self.radial_basis.n_rbf,
                self.n_interactions * n_atom_basis * 3,
                activation=None,
            )

        # initialize interaction blocks
        self.interactions = snn.replicate_module(
            lambda: PaiNNInteraction(
                n_atom_basis=self.n_atom_basis, activation=activation
            ),
            self.n_interactions,
            shared_interactions,
        )
        self.mixing = snn.replicate_module(
            lambda: PaiNNMixing(
                n_atom_basis=self.n_atom_basis, activation=activation, epsilon=epsilon
            ),
            self.n_interactions,
            shared_interactions,
        )

    def forward(self, inputs: Dict[str, torch.Tensor]):
        """
        Compute atomic representations/embeddings.

        Args:
            inputs: SchNetPack dictionary of input tensors.

        Returns:
            The same ``inputs`` dictionary, extended in place with the
            entries ``"scalar_representation"`` and
            ``"vector_representation"``.
        """
        # get tensors from input dictionary
        atomic_numbers = inputs[properties.Z]
        r_ij = inputs[properties.Rij]
        idx_i = inputs[properties.idx_i]
        idx_j = inputs[properties.idx_j]
        n_atoms = atomic_numbers.shape[0]

        # compute atom and pair features
        d_ij = torch.norm(r_ij, dim=1, keepdim=True)
        dir_ij = r_ij / d_ij
        phi_ij = self.radial_basis(d_ij)
        fcut = self.cutoff_fn(d_ij)
        filters = self.filter_net(phi_ij) * fcut[..., None]
        if self.share_filters:
            filter_list = [filters] * self.n_interactions
        else:
            filter_list = torch.split(filters, 3 * self.n_atom_basis, dim=-1)

        # compute initial embeddings; a singleton axis is inserted at
        # position 1 and removed again after the interaction loop
        q = self.embedding(atomic_numbers)[:, None]

        # add spin and charge embeddings (getattr tolerates instances that
        # lack the attribute)
        if getattr(self, "activate_charge_spin_embedding", False):
            # get tensors from input dictionary
            total_charge = inputs[properties.total_charge]
            spin = inputs[properties.spin_multiplicity]
            num_batch = len(inputs[properties.idx])
            idx_m = inputs[properties.idx_m]

            # Drop only the axis inserted above. A bare squeeze() would also
            # collapse the atom axis for a single-atom input, or the feature
            # axis when n_atom_basis == 1.
            q_atomwise = q.squeeze(1)
            charge_embedding = self.charge_embedding(
                q_atomwise, total_charge, num_batch, idx_m
            )[:, None]
            spin_embedding = self.spin_embedding(
                q_atomwise, spin, num_batch, idx_m
            )[:, None]

            # additive combining of nuclear, charge and spin embedding
            q = q + charge_embedding + spin_embedding

        # compute interaction blocks and update atomic embeddings;
        # the vector features start at zero and follow q's dtype and device
        # so that non-default precisions do not mix dtypes
        qs = q.shape
        mu = torch.zeros((qs[0], 3, qs[2]), dtype=q.dtype, device=q.device)
        for i, (interaction, mixing) in enumerate(zip(self.interactions, self.mixing)):
            q, mu = interaction(q, mu, filter_list[i], dir_ij, idx_i, idx_j, n_atoms)
            q, mu = mixing(q, mu)
        q = q.squeeze(1)

        # collect results
        inputs["scalar_representation"] = q
        inputs["vector_representation"] = mu
        return inputs