Source code for schnetpack.nn.utils

from typing import Callable, Optional

import torch
from torch import nn as nn

__all__ = ["replicate_module", "derivative_from_atomic", "derivative_from_molecular"]

from torch.autograd import grad


[docs]def replicate_module( module_factory: Callable[[], nn.Module], n: int, share_params: bool ): if share_params: module_list = nn.ModuleList([module_factory()] * n) else: module_list = nn.ModuleList([module_factory() for i in range(n)]) return module_list
def derivative_from_molecular( fx: torch.Tensor, dx: torch.Tensor, create_graph: bool = False, retain_graph: bool = False, ): """ Compute the derivative of `fx` with respect to `dx` if the leading dimension of `fx` is the number of molecules (e.g. energies, dipole moments, etc). Args: fx (torch.Tensor): Tensor for which the derivative is taken. dx (torch.Tensor): Derivative. create_graph (bool): Create computational graph. retain_graph (bool): Keep the computational graph. Returns: torch.Tensor: derivative of `fx` with respect to `dx`. """ fx_shape = fx.shape dx_shape = dx.shape # Final shape takes into consideration whether derivative will yield atomic or molecular properties final_shape = (dx_shape[0], *fx_shape[1:], *dx_shape[1:]) fx = fx.view(fx_shape[0], -1) dfdx = torch.stack( [ grad( fx[..., i], dx, torch.ones_like(fx[..., i]), create_graph=create_graph, retain_graph=retain_graph, )[0] for i in range(fx.shape[1]) ], dim=1, ) dfdx = dfdx.view(final_shape) return dfdx def derivative_from_atomic( fx: torch.Tensor, dx: torch.Tensor, n_atoms: torch.Tensor, create_graph: bool = False, retain_graph: bool = False, ): """ Compute the derivative of a tensor with the leading dimension of (batch x atoms) with respect to another tensor of either dimension (batch * atoms) (e.g. R) or (batch * atom pairs) (e.g. Rij). This function is primarily used for computing Hessians and Hessian-like response properties (e.g. nuclear spin-spin couplings). The final tensor will have the shape ( batch * atoms * atoms x ....). This is quite inefficient, use with care. Args: fx (torch.Tensor): Tensor for which the derivative is taken. dx (torch.Tensor): Derivative. n_atoms (torch.Tensor): Tensor containing the number of atoms for each molecule. create_graph (bool): Create computational graph. retain_graph (bool): Keep the computational graph. Returns: torch.Tensor: derivative of `fx` with respect to `dx`. """ # Split input tensor for easier bookkeeping fxm = fx.split(list(n_atoms)) dfdx = [] n_mol = 0 # Compute all derivatives for idx in range(len(fxm)): fx = fxm[idx].view(-1) # Generate the individual derivatives dfdx_mol = [] for i in range(fx.shape[0]): dfdx_i = grad( fx[i], dx, torch.ones_like(fx[i]), create_graph=create_graph, retain_graph=retain_graph, )[0] dfdx_mol.append(dfdx_i[n_mol : n_mol + n_atoms[idx], ...]) # Build molecular matrix and reshape dfdx_mol = torch.stack(dfdx_mol, dim=0) dfdx_mol = dfdx_mol.view(n_atoms[idx], 3, n_atoms[idx], 3) dfdx_mol = dfdx_mol.permute(0, 2, 1, 3) dfdx_mol = dfdx_mol.reshape(n_atoms[idx] ** 2, 3, 3) dfdx.append(dfdx_mol) n_mol += n_atoms[idx] # Accumulate everything dfdx = torch.cat(dfdx, dim=0) return dfdx