from typing import Dict, Optional, List, Tuple
import torch
import torch.nn as nn
from torch.autograd import grad
from schnetpack.nn.utils import derivative_from_molecular, derivative_from_atomic
import schnetpack.properties as properties
# Public names exported by a star-import of this module. ``ResponseException``
# is defined below but is not part of this list.
__all__ = ["Forces", "Strain", "Response"]
class ResponseException(Exception):
    """Exception type reserved for errors raised by the response layers."""
class Forces(nn.Module):
    """
    Predicts forces and stress as response of the energy prediction
    w.r.t. the atom positions and strain.

    The forces are the negative gradient of the energy w.r.t. the positions.
    The stress is the gradient of the energy w.r.t. the strain, divided by the
    cell volume. Computing the stress requires ``properties.strain`` to be
    present in the inputs and to be part of the autograd graph of the energy.
    """

    def __init__(
        self,
        calc_forces: bool = True,
        calc_stress: bool = False,
        energy_key: str = properties.energy,
        force_key: str = properties.forces,
        stress_key: str = properties.stress,
    ):
        """
        Args:
            calc_forces: If True, calculate atomic forces.
            calc_stress: If True, calculate the stress tensor.
            energy_key: Key of the energy in results.
            force_key: Key of the forces in results.
            stress_key: Key of the stress in results.
        """
        super().__init__()
        self.calc_forces = calc_forces
        self.calc_stress = calc_stress
        self.energy_key = energy_key
        self.force_key = force_key
        self.stress_key = stress_key

        # ``model_outputs`` names what this module writes into the inputs,
        # ``required_derivatives`` names the inputs the energy is differentiated
        # by. Positions always come first, so in ``forward`` the position
        # gradient is ``grads[0]`` and the strain gradient is ``grads[-1]``.
        self.model_outputs = []
        self.required_derivatives = []
        if calc_forces:
            self.model_outputs.append(force_key)
            self.required_derivatives.append(properties.R)
        if calc_stress:
            self.model_outputs.append(stress_key)
            self.required_derivatives.append(properties.strain)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Differentiate the predicted energy and store the requested responses.

        Args:
            inputs: Dictionary holding the energy under ``energy_key`` as well as
                the positions and/or strain (and the cell, for the stress) the
                energy depends on.

        Returns:
            The same dictionary, extended by the forces (``force_key``) and/or
            the stress (``stress_key``).
        """
        # Neither forces nor stress requested: there is nothing to differentiate
        # by, and torch.autograd.grad cannot be called without inputs.
        if len(self.required_derivatives) == 0:
            return inputs

        Epred = inputs[self.energy_key]

        go: List[Optional[torch.Tensor]] = [torch.ones_like(Epred)]
        grads = grad(
            [Epred],
            [inputs[prop] for prop in self.required_derivatives],
            grad_outputs=go,
            # Keep the graph during training so a loss on forces/stress can be
            # back-propagated through the derivative.
            create_graph=self.training,
        )

        if self.calc_forces:
            dEdR = grads[0]
            # TorchScript needs Tensor instead of Optional[Tensor]
            if dEdR is None:
                dEdR = torch.zeros_like(inputs[properties.R])

            inputs[self.force_key] = -dEdR

        if self.calc_stress:
            stress = grads[-1]
            # TorchScript needs Tensor instead of Optional[Tensor]
            if stress is None:
                stress = torch.zeros_like(inputs[properties.cell])

            # Cell volume as the triple product a . (b x c) of the three cell
            # vectors, reshaped to broadcast against the per-cell 3x3 stress.
            cell = inputs[properties.cell]
            volume = torch.sum(
                cell[:, 0, :] * torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
                dim=1,
                keepdim=True,
            )[:, :, None]
            inputs[self.stress_key] = stress / volume

        return inputs
class Response(nn.Module):
    """
    Response layer that derives properties from derivatives of an energy.

    Based on the requested response properties, the layer determines which first
    derivatives of the energy are taken in a single autograd call ("basic"
    derivatives) and which higher derivatives are built on top of them. The
    results are written into the input dictionary under the (optionally renamed)
    property names.
    """

    # Properties this layer knows how to compute.
    implemented_properties = [
        properties.forces,
        properties.stress,
        properties.hessian,
        properties.dipole_moment,
        properties.polarizability,
        properties.dipole_derivatives,
        properties.partial_charges,
        properties.polarizability_derivatives,
        properties.shielding,
        properties.nuclear_spin_coupling,
    ]

    def __init__(
        self,
        energy_key: str,
        response_properties: List[str],
        map_properties: Optional[Dict[str, str]] = None,
    ):
        """
        Compute different response properties by taking derivatives of an energy model. See [#field1]_ for details.

        Args:
            energy_key (str): key indicating the energy property used for response calculations.
            response_properties (list(str)): List of requested response properties.
            map_properties (dict(str,str)): Dictionary for mapping property names. The keys are the names as computed
                                            by the response layer (default `schnetpack.properties`), the values the
                                            new names. The dictionary is copied and not modified.

        Raises:
            NotImplementedError: If a requested property is not listed in `implemented_properties`.

        References:
        -----------
        .. [#field1] Gastegger, Schütt, Müller:
           Machine learning of solvent effects on molecular spectra and reactions.
           Chemical Science, 12(34), 11473-11483. 2021.
        """
        super(Response, self).__init__()

        for prop in response_properties:
            if prop not in self.implemented_properties:
                raise NotImplementedError(
                    "Property {:s} not implemented in response layer.".format(prop)
                )

        self.energy_key = energy_key
        self.response_properties = response_properties

        # Work on a copy: the identity mappings added below must not leak into
        # the dictionary owned by the caller.
        self.map_properties = {} if map_properties is None else dict(map_properties)
        for prop in self.response_properties:
            self.map_properties.setdefault(prop, prop)

        # forward() stores every result under its mapped name, so these are the
        # names that actually appear in the output dictionary.
        self.model_outputs = list(self.map_properties.values())

        # Set up instructions for computing response properties and derivatives
        (
            basic_derivatives,
            required_derivatives,
            derivative_instructions,
            graph_required,
        ) = self._construct_properties()

        # Basic and required can not be merged
        self.basic_derivatives = basic_derivatives
        self.required_derivatives = required_derivatives
        self.derivative_instructions = derivative_instructions
        self.graph_required = graph_required

        # Check whether basic graph is enough or higher level derivatives are necessary:
        # if more derivatives are enabled than there are basic ones, some of them
        # are taken on top of the basic derivatives, which then need a graph.
        n_enabled = sum(1 for flag in self.derivative_instructions.values() if flag)
        self.basic_graph_required = len(self.basic_derivatives) != n_enabled

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Compute the requested response properties from the energy in `inputs`.

        Args:
            inputs: Dictionary holding the energy under `energy_key` and every
                quantity the energy is differentiated by.

        Returns:
            The same dictionary, extended by one entry per key of
            `map_properties`, stored under the mapped name.

        Raises:
            KeyError: If `map_properties` contains a key for which no result
                was computed.
        """
        energy = inputs[self.energy_key]

        # Compute base level derivatives
        go: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
        basic_derivatives = grad(
            [energy],
            [inputs[prop] for prop in self.basic_derivatives.values()],
            grad_outputs=go,
            create_graph=(self.basic_graph_required or self.training),
            retain_graph=(self.basic_graph_required or self.training),
        )

        # Convert to dictionary (same key order as the inputs passed to grad)
        basic_derivatives = dict(zip(self.basic_derivatives.keys(), basic_derivatives))

        # Results are collected under the default property names and renamed
        # only when they are copied into `inputs` at the end.
        results = {}

        # ================================
        # dE / dR
        # ================================
        if self.derivative_instructions["dEdR"]:
            # basic distance derivatives
            if properties.forces in self.response_properties:
                results[properties.forces] = -basic_derivatives["dEdR"]

            if self.derivative_instructions["d2EdR2"]:
                d2EdR2 = derivative_from_atomic(
                    basic_derivatives["dEdR"],
                    inputs[properties.R],
                    inputs[properties.n_atoms],
                    create_graph=(self.graph_required["d2EdR2"] or self.training),
                    retain_graph=True,
                )
                results[properties.hessian] = d2EdR2

        # ================================
        # dE / ds
        # ================================
        if self.derivative_instructions["dEds"]:
            stress = basic_derivatives["dEds"]
            # TorchScript needs Tensor instead of Optional[Tensor]
            if stress is None:
                stress = torch.zeros_like(inputs[properties.cell])

            # Cell volume as the triple product a . (b x c) of the cell vectors,
            # reshaped to broadcast against the per-cell 3x3 strain derivative.
            cell = inputs[properties.cell]
            volume = torch.sum(
                cell[:, 0, :] * torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
                dim=1,
                keepdim=True,
            )[:, :, None]
            results[properties.stress] = stress / volume

        # ================================
        # dE / dF
        # ================================
        if self.derivative_instructions["dEdF"]:
            dEdF = basic_derivatives["dEdF"]
            # The dipole moment is stored whenever dE/dF is available, even if
            # only a higher electric field derivative was requested.
            results[properties.dipole_moment] = -dEdF

            if self.derivative_instructions["d2EdFdR"]:
                d2EdFdR = derivative_from_molecular(
                    -dEdF,
                    inputs[properties.R],
                    create_graph=(self.graph_required["d2EdFdR"] or self.training),
                    retain_graph=True,
                )
                results[properties.dipole_derivatives] = d2EdFdR

                # Compute partial charges if requested: one third of the trace
                # over the last two axes of the dipole derivatives.
                if properties.partial_charges in self.response_properties:
                    results[properties.partial_charges] = (
                        torch.einsum("bii->b", d2EdFdR) / 3.0
                    )

            if self.derivative_instructions["d2EdF2"]:
                d2EdF2 = derivative_from_molecular(
                    -dEdF,
                    inputs[properties.electric_field],
                    create_graph=(self.graph_required["d2EdF2"] or self.training),
                    retain_graph=True,
                )
                results[properties.polarizability] = d2EdF2

                if self.derivative_instructions["d3EdF2dR"]:
                    d3EdF2dR = derivative_from_molecular(
                        d2EdF2,
                        inputs[properties.R],
                        create_graph=(self.graph_required["d3EdF2dR"] or self.training),
                        retain_graph=True,
                    )
                    results[properties.polarizability_derivatives] = d3EdF2dR

        # ================================
        # dE / dB
        # ================================
        if self.derivative_instructions["dEdB"]:
            dEdB = basic_derivatives["dEdB"]
            results["dEdB"] = dEdB

            if self.derivative_instructions["d2EdBdI"]:
                d2EdBdI = derivative_from_molecular(
                    dEdB,
                    inputs[properties.nuclear_magnetic_moments],
                    create_graph=(self.graph_required["d2EdBdI"] or self.training),
                    retain_graph=True,
                )
                results[properties.shielding] = d2EdBdI

        # ================================
        # dE / dI
        # ================================
        if self.derivative_instructions["dEdI"]:
            dEdI = basic_derivatives["dEdI"]
            results["dEdI"] = dEdI

            if self.derivative_instructions["d2EdI2"]:
                d2EdI2 = derivative_from_atomic(
                    dEdI,
                    inputs[properties.nuclear_magnetic_moments],
                    inputs[properties.n_atoms],
                    create_graph=(self.graph_required["d2EdI2"] or self.training),
                    retain_graph=True,
                )
                results[properties.nuclear_spin_coupling] = d2EdI2

        # Copy the results into the inputs under their mapped names.
        for prop in self.map_properties:
            inputs[self.map_properties[prop]] = results[prop]

        return inputs

    def _construct_properties(
        self,
    ) -> Tuple[Dict[str, str], List[str], Dict[str, bool], Dict[str, bool]]:
        """
        Routine for automatically determining the computational settings of the response
        layer based on the requested response properties.

        Based on the requested response properties, determine:
            - which derivatives need to be computed
            - which properties need to be enabled for gradient computation
            - for which derivatives does a graph need to be constructed

        Returns:
            - dictionary of basic derivatives (derivative name -> input differentiated by)
            - list of variables which need gradients (sorted, without duplicates)
            - dictionary of derivative instructions
            - dictionary of required graphs
        """
        derivative_names = (
            "dEdR",
            "d2EdR2",
            "dEdF",
            "d2EdFdR",
            "d2EdF2",
            "d3EdF2dR",
            "dEdB",
            "dEdI",
            "d2EdBdI",
            "d2EdI2",
            "dEds",
        )
        # Everything is switched off until a requested property needs it.
        derivative_instructions = dict.fromkeys(derivative_names, False)
        graph_required = dict.fromkeys(derivative_names, False)

        required_derivatives = set()
        basic_derivatives = dict()

        requested = self.response_properties

        # position derivatives
        if properties.forces in requested or properties.hessian in requested:
            derivative_instructions["dEdR"] = True
            required_derivatives.add(properties.R)
            basic_derivatives["dEdR"] = properties.R

            if properties.hessian in requested:
                graph_required["dEdR"] = True
                derivative_instructions["d2EdR2"] = True

        # strain derivatives
        if properties.stress in requested:
            derivative_instructions["dEds"] = True
            required_derivatives.add(properties.strain)
            basic_derivatives["dEds"] = properties.strain

        # electric field derivatives
        wants_dipole_derivatives = (
            properties.dipole_derivatives in requested
            or properties.partial_charges in requested
        )
        wants_polarizability = (
            properties.polarizability in requested
            or properties.polarizability_derivatives in requested
        )
        if (
            properties.dipole_moment in requested
            or wants_dipole_derivatives
            or wants_polarizability
        ):
            derivative_instructions["dEdF"] = True
            required_derivatives.add(properties.electric_field)
            basic_derivatives["dEdF"] = properties.electric_field

            if wants_dipole_derivatives:
                graph_required["dEdF"] = True
                derivative_instructions["d2EdFdR"] = True
                required_derivatives.add(properties.R)

            if wants_polarizability:
                graph_required["dEdF"] = True
                derivative_instructions["d2EdF2"] = True

                if properties.polarizability_derivatives in requested:
                    graph_required["d2EdF2"] = True
                    derivative_instructions["d3EdF2dR"] = True
                    required_derivatives.add(properties.R)

        # magnetic moment derivatives
        if properties.nuclear_spin_coupling in requested:
            # First derivative
            required_derivatives.add(properties.nuclear_magnetic_moments)
            basic_derivatives["dEdI"] = properties.nuclear_magnetic_moments
            derivative_instructions["dEdI"] = True
            # Second derivative for couplings
            graph_required["dEdI"] = True
            derivative_instructions["d2EdI2"] = True

        # magnetic field derivatives
        if properties.shielding in requested:
            # First derivative
            required_derivatives.add(properties.magnetic_field)
            basic_derivatives["dEdB"] = properties.magnetic_field
            derivative_instructions["dEdB"] = True
            # Second derivative for shielding
            required_derivatives.add(properties.nuclear_magnetic_moments)
            graph_required["dEdB"] = True
            derivative_instructions["d2EdBdI"] = True

        # Convert back to list; sorted so the order does not depend on set
        # iteration order, which may differ between interpreter runs.
        required_derivatives = sorted(required_derivatives)

        return (
            basic_derivatives,
            required_derivatives,
            derivative_instructions,
            graph_required,
        )
class Strain(nn.Module):
    """
    This is required to calculate the stress as a response property.

    Creates a zero strain tensor with gradients enabled, stores it under
    ``properties.strain`` and makes the cell, the atom positions and the offsets
    depend on it, so that the energy can be differentiated w.r.t. the strain.
    """

    def forward(self, inputs: Dict[str, torch.Tensor]):
        """
        Add the strain tensor and apply it to cell, positions and offsets.

        Args:
            inputs: Dictionary holding the cell, positions, offsets and the
                index tensors ``idx_m`` and ``idx_i``.

        Returns:
            The same dictionary with the strain added and the cell, positions
            and offsets replaced by their strained counterparts.
        """
        cell = inputs[properties.cell]

        # One strain matrix per cell; all zeros, so the values of the strained
        # quantities are unchanged and only the autograd dependency is added.
        zero_strain = torch.zeros_like(cell)
        zero_strain.requires_grad_()
        inputs[properties.strain] = zero_strain
        strain_t = zero_strain.transpose(1, 2)

        # strain cell
        inputs[properties.cell] = cell + torch.matmul(cell, strain_t)

        # strain positions: select one strain matrix per position via idx_m
        # (presumably the index of the system each atom belongs to -- confirm)
        strain_atoms = strain_t[inputs[properties.idx_m]]
        positions = inputs[properties.R]
        strained_shift = torch.matmul(positions[:, None, :], strain_atoms)
        inputs[properties.R] = positions + strained_shift.squeeze(1)

        # strain offsets: select one strain matrix per offset via idx_i
        # (presumably the index of the central atom of each pair -- confirm)
        strain_pairs = strain_atoms[inputs[properties.idx_i]]
        offsets = inputs[properties.offsets]
        strained_offsets = torch.matmul(offsets[:, None, :], strain_pairs)
        inputs[properties.offsets] = offsets + strained_offsets.squeeze(1)

        return inputs