from typing import Sequence, Union, Callable, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import schnetpack as spk
import schnetpack.nn as snn
import schnetpack.properties as properties
__all__ = ["Atomwise", "DipoleMoment", "Polarizability"]
class Atomwise(nn.Module):
"""
Predicts atom-wise contributions and accumulates them into a global prediction, e.g. the energy.
If `aggregation_mode` is None, only the per-atom predictions will be returned.
"""
def __init__(
self,
n_in: int,
n_out: int = 1,
n_hidden: Optional[Union[int, Sequence[int]]] = None,
n_layers: int = 2,
activation: Callable = F.silu,
aggregation_mode: str = "sum",
output_key: str = "y",
per_atom_output_key: Optional[str] = None,
):
"""
Args:
n_in: input dimension of representation
n_out: output dimension of target property (default: 1)
n_hidden: size of hidden layers.
If an integer, the same number of neurons is used for all hidden layers,
resulting in a rectangular network.
If None, the number of neurons is divided by two after each layer, starting
with n_in, resulting in a pyramidal network.
n_layers: number of layers.
aggregation_mode: one of {sum, avg} (default: sum)
output_key: the key under which the result will be stored
per_atom_output_key: If not None, the key under which the per-atom result will be stored
"""
super(Atomwise, self).__init__()
self.output_key = output_key
self.model_outputs = [output_key]
self.per_atom_output_key = per_atom_output_key
if self.per_atom_output_key is not None:
self.model_outputs.append(self.per_atom_output_key)
self.n_out = n_out
if aggregation_mode is None and self.per_atom_output_key is None:
raise ValueError(
"If `aggregation_mode` is None, `per_atom_output_key` needs to be set,"
+ " since no accumulated output will be returned!"
)
self.outnet = spk.nn.build_mlp(
n_in=n_in,
n_out=n_out,
n_hidden=n_hidden,
n_layers=n_layers,
activation=activation,
)
self.aggregation_mode = aggregation_mode
def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# predict atomwise contributions
y = self.outnet(inputs["scalar_representation"])
# accumulate the per-atom output if necessary
if self.per_atom_output_key is not None:
inputs[self.per_atom_output_key] = y
# aggregate
if self.aggregation_mode is not None:
idx_m = inputs[properties.idx_m]
maxm = int(idx_m[-1]) + 1
y = snn.scatter_add(y, idx_m, dim_size=maxm)
y = torch.squeeze(y, -1)
if self.aggregation_mode == "avg":
y = y / inputs[properties.n_atoms]
inputs[self.output_key] = y
return inputs
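# Illustrative sketch (not part of the original module): how an Atomwise head is
# typically configured as an energy output on top of a representation with 128
# features; the feature size and the output keys ("energy", "energy_per_atom")
# are assumptions for this example only.
def _example_atomwise_energy_head() -> "Atomwise":
    # sum per-atom contributions into a total energy and additionally expose
    # the per-atom terms under their own key
    return Atomwise(
        n_in=128,
        n_out=1,
        aggregation_mode="sum",
        output_key="energy",
        per_atom_output_key="energy_per_atom",
    )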
class DipoleMoment(nn.Module):
"""
Predicts dipole moments from latent partial charges and (optionally) local, atomic dipoles.
The latter requires a representation supplying (equivariant) vector features.
References:
.. [#painn1] Schütt, Unke, Gastegger.
Equivariant message passing for the prediction of tensorial properties and molecular spectra.
ICML 2021, http://proceedings.mlr.press/v139/schutt21a.html
.. [#irspec] Gastegger, Behler, Marquetand.
Machine learning molecular dynamics for the simulation of infrared spectra.
Chemical science 8.10 (2017): 6924-6935.
.. [#dipole] Veit et al.
Predicting molecular dipole moments by combining atomic partial charges and atomic dipoles.
The Journal of Chemical Physics 153.2 (2020): 024113.
"""
def __init__(
self,
n_in: int,
n_hidden: Optional[Union[int, Sequence[int]]] = None,
n_layers: int = 2,
activation: Callable = F.silu,
predict_magnitude: bool = False,
return_charges: bool = False,
dipole_key: str = properties.dipole_moment,
charges_key: str = properties.partial_charges,
correct_charges: bool = True,
use_vector_representation: bool = False,
):
"""
Args:
n_in: input dimension of representation
n_hidden: size of hidden layers.
If an integer, the same number of neurons is used for all hidden layers,
resulting in a rectangular network.
If None, the number of neurons is divided by two after each layer, starting
with n_in, resulting in a pyramidal network.
n_layers: number of layers.
activation: activation function
predict_magnitude: If true, calculate magnitude of dipole
return_charges: If true, return latent partial charges
dipole_key: the key under which the dipoles will be stored
charges_key: the key under which partial charges will be stored
correct_charges: If true, constrain the sum of partial charges to equal the
total charge (if provided in the inputs) and zero otherwise.
use_vector_representation: If true, use vector representation to predict
local, atomic dipoles.
"""
super().__init__()
self.dipole_key = dipole_key
self.charges_key = charges_key
self.return_charges = return_charges
self.model_outputs = [dipole_key]
if self.return_charges:
self.model_outputs.append(charges_key)
self.predict_magnitude = predict_magnitude
self.use_vector_representation = use_vector_representation
self.correct_charges = correct_charges
if use_vector_representation:
self.outnet = spk.nn.build_gated_equivariant_mlp(
n_in=n_in,
n_out=1,
n_hidden=n_hidden,
n_layers=n_layers,
activation=activation,
sactivation=activation,
)
else:
self.outnet = spk.nn.build_mlp(
n_in=n_in,
n_out=1,
n_hidden=n_hidden,
n_layers=n_layers,
activation=activation,
)
def forward(self, inputs):
positions = inputs[properties.R]
l0 = inputs["scalar_representation"]
natoms = inputs[properties.n_atoms]
idx_m = inputs[properties.idx_m]
maxm = int(idx_m[-1]) + 1
if self.use_vector_representation:
l1 = inputs["vector_representation"]
charges, atomic_dipoles = self.outnet((l0, l1))
atomic_dipoles = torch.squeeze(atomic_dipoles, -1)
else:
charges = self.outnet(l0)
atomic_dipoles = 0.0
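# correct the latent charges so that they sum to the reference total charge
# of each molecule (zero if none is provided), distributing the residual
# equally over the atoms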
if self.correct_charges:
sum_charge = snn.scatter_add(charges, idx_m, dim_size=maxm)
if properties.total_charge in inputs:
total_charge = inputs[properties.total_charge][:, None]
else:
total_charge = torch.zeros_like(sum_charge)
charge_correction = (total_charge - sum_charge) / natoms.unsqueeze(-1)
charge_correction = charge_correction[idx_m]
charges = charges + charge_correction
if self.return_charges:
inputs[self.charges_key] = charges
y = positions * charges
if self.use_vector_representation:
y = y + atomic_dipoles
# sum over atoms
y = snn.scatter_add(y, idx_m, dim_size=maxm)
if self.predict_magnitude:
y = torch.norm(y, dim=1, keepdim=False)
inputs[self.dipole_key] = y
return inputs
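# Illustrative sketch (not part of the original module): configuring a
# DipoleMoment head on top of an equivariant representation such as PaiNN.
# The feature size (128) and the choice to return charges are assumptions.
def _example_dipole_head() -> "DipoleMoment":
    # vector features provide local atomic dipoles; latent charges are
    # corrected to the reference total charge (zero if none is given)
    return DipoleMoment(
        n_in=128,
        return_charges=True,
        correct_charges=True,
        use_vector_representation=True,
    )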
class Polarizability(nn.Module):
"""
Predicts polarizability tensor using tensor rank factorization.
This requires an equivariant representation, e.g. PaiNN, that provides both scalar and vectorial features.
References:
.. [#painn1a] Schütt, Unke, Gastegger:
Equivariant message passing for the prediction of tensorial properties and molecular spectra.
ICML 2021, http://proceedings.mlr.press/v139/schutt21a.html
"""
def __init__(
self,
n_in: int,
n_hidden: Optional[Union[int, Sequence[int]]] = None,
n_layers: int = 2,
activation: Callable = F.silu,
polarizability_key: str = properties.polarizability,
):
"""
Args:
n_in: input dimension of representation
n_hidden: size of hidden layers.
If an integer, the same number of neurons is used for all hidden layers,
resulting in a rectangular network.
If None, the number of neurons is divided by two after each layer, starting
with n_in, resulting in a pyramidal network.
n_layers: number of layers.
activation: activation function
polarizability_key: the key under which the predicted polarizability will be stored
"""
super(Polarizability, self).__init__()
self.n_in = n_in
self.n_layers = n_layers
self.n_hidden = n_hidden
self.polarizability_key = polarizability_key
self.model_outputs = [polarizability_key]
self.outnet = spk.nn.build_gated_equivariant_mlp(
n_in=n_in,
n_out=1,
n_hidden=n_hidden,
n_layers=n_layers,
activation=activation,
sactivation=activation,
)
self.requires_dr = False
self.requires_stress = False
def forward(self, inputs):
positions = inputs[properties.R]
l0 = inputs["scalar_representation"]
l1 = inputs["vector_representation"]
dim = l1.shape[-2]
l0, l1 = self.outnet((l0, l1))
# isotropic on diagonal
alpha = l0[..., 0:1]
size = list(alpha.shape)
size[-1] = dim
alpha = alpha.expand(*size)
alpha = torch.diag_embed(alpha)
# add anisotropic components
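# symmetrized outer product of the predicted atomic vector and the atom
# position contributes the anisotropic part of the tensor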
mur = l1[..., None, 0] * positions[..., None, :]
alpha_c = mur + mur.transpose(-2, -1)
alpha = alpha + alpha_c
# sum over atoms
idx_m = inputs[properties.idx_m]
maxm = int(idx_m[-1]) + 1
alpha = snn.scatter_add(alpha, idx_m, dim_size=maxm)
inputs[self.polarizability_key] = alpha
return inputs
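# Illustrative sketch (not part of the original module): a Polarizability head
# requires a representation that fills both "scalar_representation" and
# "vector_representation" (e.g. PaiNN); n_in=128 is an assumption for this
# example only.
def _example_polarizability_head() -> "Polarizability":
    return Polarizability(n_in=128)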