Source code for atomistic.atomwise

from typing import Sequence, Union, Callable, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import schnetpack as spk
import schnetpack.nn as snn
import schnetpack.properties as properties

__all__ = ["Atomwise", "DipoleMoment", "Polarizability"]

class Atomwise(nn.Module):
    """
    Predicts atom-wise contributions and accumulates global prediction, e.g. for
    the energy.

    If `aggregation_mode` is None, only the per-atom predictions will be returned.
    """

    def __init__(
        self,
        n_in: int,
        n_out: int = 1,
        n_hidden: Optional[Union[int, Sequence[int]]] = None,
        n_layers: int = 2,
        activation: Callable = F.silu,
        aggregation_mode: str = "sum",
        output_key: str = "y",
        per_atom_output_key: Optional[str] = None,
    ):
        """
        Args:
            n_in: input dimension of representation
            n_out: output dimension of target property (default: 1)
            n_hidden: size of hidden layers.
                If an integer, same number of node is used for all hidden layers
                resulting in a rectangular network.
                If None, the number of neurons is divided by two after each layer
                starting n_in resulting in a pyramidal network.
            n_layers: number of layers.
            activation: activation function passed on to the output network
            aggregation_mode: one of {sum, avg} (default: sum). If None, no
                aggregation over atoms is performed.
            output_key: the key under which the result will be stored
            per_atom_output_key: If not None, the key under which the per-atom
                result will be stored

        Raises:
            ValueError: if both `aggregation_mode` and `per_atom_output_key`
                are None.
        """
        super(Atomwise, self).__init__()
        self.output_key = output_key
        self.model_outputs = [output_key]
        self.per_atom_output_key = per_atom_output_key
        if self.per_atom_output_key is not None:
            self.model_outputs.append(self.per_atom_output_key)
        self.n_out = n_out

        if aggregation_mode is None and self.per_atom_output_key is None:
            raise ValueError(
                "If `aggregation_mode` is None, `per_atom_output_key` needs to be set,"
                + " since no accumulated output will be returned!"
            )

        self.outnet = spk.nn.build_mlp(
            n_in=n_in,
            n_out=n_out,
            n_hidden=n_hidden,
            n_layers=n_layers,
            activation=activation,
        )
        self.aggregation_mode = aggregation_mode

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Run the output network on `inputs["scalar_representation"]` and store
        the result in `inputs`.

        The per-atom prediction is written to `per_atom_output_key` (if set).
        Unless `aggregation_mode` is None, the per-atom values are summed per
        molecule using `inputs[properties.idx_m]` and, for "avg", divided by
        `inputs[properties.n_atoms]`. The final value is stored under
        `output_key`; with `aggregation_mode=None` that is the per-atom
        prediction itself.

        Returns:
            The same `inputs` dictionary, updated in place.
        """
        # predict atomwise contributions
        y = self.outnet(inputs["scalar_representation"])

        # accumulate the per-atom output if necessary
        if self.per_atom_output_key is not None:
            inputs[self.per_atom_output_key] = y

        # aggregate
        if self.aggregation_mode is not None:
            idx_m = inputs[properties.idx_m]
            # number of molecules = largest molecule index + 1
            # (reads the last entry, so presumably idx_m is sorted -- confirm)
            maxm = int(idx_m[-1]) + 1
            y = snn.scatter_add(y, idx_m, dim_size=maxm)

            if self.aggregation_mode == "avg":
                # Divide while the trailing output axis is still present and
                # give n_atoms a matching trailing axis. Dividing after the
                # squeeze only broadcasts correctly for n_out == 1; for
                # n_out > 1 it fails or divides along the wrong axis.
                # assumes n_atoms holds one entry per molecule -- TODO confirm
                y = y / inputs[properties.n_atoms].unsqueeze(-1)

            # drop the trailing axis for scalar targets (no-op if n_out > 1)
            y = torch.squeeze(y, -1)

        inputs[self.output_key] = y
        return inputs
class DipoleMoment(nn.Module):
    """
    Predicts dipole moments from latent partial charges and (optionally) local,
    atomic dipoles. The latter requires a representation supplying (equivariant)
    vector features.

    References:

    .. [#painn1] Schütt, Unke, Gastegger.
       Equivariant message passing for the prediction of tensorial properties
       and molecular spectra.
       ICML 2021.
    .. [#irspec] Gastegger, Behler, Marquetand.
       Machine learning molecular dynamics for the simulation of infrared spectra.
       Chemical science 8.10 (2017): 6924-6935.
    .. [#dipole] Veit et al.
       Predicting molecular dipole moments by combining atomic partial charges
       and atomic dipoles.
       The Journal of Chemical Physics 153.2 (2020): 024113.
    """

    def __init__(
        self,
        n_in: int,
        n_hidden: Optional[Union[int, Sequence[int]]] = None,
        n_layers: int = 2,
        activation: Callable = F.silu,
        predict_magnitude: bool = False,
        return_charges: bool = False,
        dipole_key: str = properties.dipole_moment,
        charges_key: str = properties.partial_charges,
        correct_charges: bool = True,
        use_vector_representation: bool = False,
    ):
        """
        Args:
            n_in: input dimension of representation
            n_hidden: size of hidden layers.
                If an integer, same number of node is used for all hidden layers
                resulting in a rectangular network.
                If None, the number of neurons is divided by two after each layer
                starting n_in resulting in a pyramidal network.
            n_layers: number of layers.
            activation: activation function
            predict_magnitude: If true, calculate magnitude of dipole
            return_charges: If true, return latent partial charges
            dipole_key: the key under which the dipoles will be stored
            charges_key: the key under which partial charges will be stored
            correct_charges: If true, forces the sum of partial charges to be
                the total charge, if provided, and zero otherwise.
            use_vector_representation: If true, use vector representation to
                predict local, atomic dipoles.
        """
        super().__init__()

        # behavioural switches
        self.predict_magnitude = predict_magnitude
        self.return_charges = return_charges
        self.correct_charges = correct_charges
        self.use_vector_representation = use_vector_representation

        # output bookkeeping: charges are only advertised when they are returned
        self.dipole_key = dipole_key
        self.charges_key = charges_key
        self.model_outputs = (
            [dipole_key, charges_key] if return_charges else [dipole_key]
        )

        # both network flavours share these settings and emit a single channel
        mlp_kwargs = dict(
            n_in=n_in,
            n_out=1,
            n_hidden=n_hidden,
            n_layers=n_layers,
            activation=activation,
        )
        if use_vector_representation:
            self.outnet = spk.nn.build_gated_equivariant_mlp(
                sactivation=activation, **mlp_kwargs
            )
        else:
            self.outnet = spk.nn.build_mlp(**mlp_kwargs)

    def forward(self, inputs):
        """
        Compute per-molecule dipole moments and store them in `inputs`.

        Latent charges come from the output network; each atom contributes
        `position * charge` (plus the network's vector output when
        `use_vector_representation` is set), and the contributions are summed
        per molecule via `inputs[properties.idx_m]`.

        Returns:
            The same `inputs` dictionary, with `dipole_key` (and `charges_key`
            if `return_charges` is set) filled in.
        """
        positions = inputs[properties.R]
        scalar_features = inputs["scalar_representation"]
        natoms = inputs[properties.n_atoms]
        idx_m = inputs[properties.idx_m]
        n_molecules = int(idx_m[-1]) + 1

        if self.use_vector_representation:
            vector_features = inputs["vector_representation"]
            charges, local_dipoles = self.outnet((scalar_features, vector_features))
            local_dipoles = torch.squeeze(local_dipoles, -1)
        else:
            charges = self.outnet(scalar_features)
            local_dipoles = None

        if self.correct_charges:
            # spread the per-molecule charge mismatch evenly over its atoms
            predicted_total = snn.scatter_add(charges, idx_m, dim_size=n_molecules)

            if properties.total_charge in inputs:
                target_total = inputs[properties.total_charge][:, None]
            else:
                target_total = torch.zeros_like(predicted_total)

            shift = (target_total - predicted_total) / natoms.unsqueeze(-1)
            charges = charges + shift[idx_m]

        if self.return_charges:
            inputs[self.charges_key] = charges

        contributions = positions * charges
        if local_dipoles is not None:
            contributions = contributions + local_dipoles

        # sum over atoms
        dipole = snn.scatter_add(contributions, idx_m, dim_size=n_molecules)

        if self.predict_magnitude:
            dipole = torch.norm(dipole, dim=1, keepdim=False)

        inputs[self.dipole_key] = dipole
        return inputs
class Polarizability(nn.Module):
    """
    Predicts polarizability tensor using tensor rank factorization. This
    requires an equivariant representation, e.g. PaiNN, that provides both
    scalar and vectorial features.

    References:

    .. [#painn1a] Schütt, Unke, Gastegger:
       Equivariant message passing for the prediction of tensorial properties
       and molecular spectra.
       ICML 2021.
    """

    def __init__(
        self,
        n_in: int,
        n_hidden: Optional[Union[int, Sequence[int]]] = None,
        n_layers: int = 2,
        activation: Callable = F.silu,
        polarizability_key: str = properties.polarizability,
    ):
        """
        Args:
            n_in: input dimension of representation
            n_hidden: size of hidden layers.
                If an integer, same number of node is used for all hidden layers
                resulting in a rectangular network.
                If None, the number of neurons is divided by two after each layer
                starting n_in resulting in a pyramidal network.
            n_layers: number of layers.
            activation: activation function
            polarizability_key: the key under which the predicted polarizability
                will be stored
        """
        super().__init__()

        # keep the architecture settings around for inspection
        self.n_in = n_in
        self.n_hidden = n_hidden
        self.n_layers = n_layers

        self.polarizability_key = polarizability_key
        self.model_outputs = [polarizability_key]

        self.outnet = spk.nn.build_gated_equivariant_mlp(
            n_in=n_in,
            n_out=1,
            n_hidden=n_hidden,
            n_layers=n_layers,
            activation=activation,
            sactivation=activation,
        )

        self.requires_dr = False
        self.requires_stress = False

    def forward(self, inputs):
        """
        Compute per-molecule polarizability tensors and store them in `inputs`.

        The scalar network output is placed on the diagonal, a symmetrised
        outer product of the vector output with the atom positions is added,
        and the per-atom tensors are summed per molecule via
        `inputs[properties.idx_m]`.

        Returns:
            The same `inputs` dictionary, with `polarizability_key` filled in.
        """
        positions = inputs[properties.R]
        scalar_in = inputs["scalar_representation"]
        vector_in = inputs["vector_representation"]
        # size of the second-to-last axis of the vector features
        # (presumably the Cartesian axis -- confirm against the representation)
        n_components = vector_in.shape[-2]

        scalar_out, vector_out = self.outnet((scalar_in, vector_in))

        # isotropic on diagonal
        isotropic = scalar_out[..., 0:1]
        isotropic = isotropic.expand(*isotropic.shape[:-1], n_components)
        alpha = torch.diag_embed(isotropic)

        # add anisotropic components
        outer = vector_out[..., None, 0] * positions[..., None, :]
        alpha = alpha + (outer + outer.transpose(-2, -1))

        # sum over atoms
        idx_m = inputs[properties.idx_m]
        n_molecules = int(idx_m[-1]) + 1
        inputs[self.polarizability_key] = snn.scatter_add(
            alpha, idx_m, dim_size=n_molecules
        )
        return inputs