from typing import Union, Sequence, Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import schnetpack.nn as snn
from schnetpack.nn.activations import shifted_softplus
# Names exported by a star-import of this module; only the two builder
# functions are listed here.
__all__ = ["build_mlp", "build_gated_equivariant_mlp"]
def build_mlp(
    n_in: int,
    n_out: int,
    n_hidden: Optional[Union[int, Sequence[int]]] = None,
    n_layers: int = 2,
    activation: Callable = F.silu,
    last_bias: bool = True,
    last_zero_init: bool = False,
) -> nn.Module:
    """
    Build a fully connected multi-layer perceptron.

    Args:
        n_in: number of input nodes.
        n_out: number of output nodes.
        n_hidden: number of hidden layer nodes.
            If an integer, the same number of nodes is used for all hidden
            layers, resulting in a rectangular network.
            If a sequence, it gives the size of every hidden layer; the depth
            of the network is then taken from the sequence and ``n_layers``
            is not used.
            If None, the number of neurons is halved after each layer
            (starting from ``n_in`` and never dropping below ``n_out``),
            resulting in a pyramidal network.
        n_layers: number of layers (hidden layers plus the output layer).
            Only used when ``n_hidden`` is None or an integer.
        activation: activation function. All hidden layers use the same
            activation function; the output layer applies none.
        last_bias: passed as ``bias`` to the output layer.
        last_zero_init: if True, the weights of the output layer are
            initialized with zeros; otherwise the ``Dense`` default
            initializer is used.

    Returns:
        ``nn.Sequential`` of ``snn.Dense`` layers.
    """
    # get list of number of nodes in input, hidden & output layers
    if n_hidden is None:
        n_neurons = []
        c_neurons = n_in
        for _ in range(n_layers):
            n_neurons.append(c_neurons)
            c_neurons = max(n_out, c_neurons // 2)
        n_neurons.append(n_out)
    else:
        # get list of number of nodes hidden layers
        if isinstance(n_hidden, int):
            n_hidden = [n_hidden] * (n_layers - 1)
        else:
            n_hidden = list(n_hidden)
        n_neurons = [n_in] + n_hidden + [n_out]

    # Assign a Dense layer (with activation function) to each hidden layer.
    # The count comes from n_neurons rather than n_layers so that an explicit
    # n_hidden sequence of any length yields a dimensionally consistent net.
    layers = [
        snn.Dense(n_neurons[i], n_neurons[i + 1], activation=activation)
        for i in range(len(n_neurons) - 2)
    ]

    # Assign a Dense layer (without activation function) to the output layer.
    # weight_init is only overridden when zero initialization is requested,
    # so the Dense default initializer applies otherwise.
    last_kwargs = {"weight_init": torch.nn.init.zeros_} if last_zero_init else {}
    layers.append(
        snn.Dense(
            n_neurons[-2],
            n_neurons[-1],
            activation=None,
            bias=last_bias,
            **last_kwargs,
        )
    )

    # put all layers together to make the network
    return nn.Sequential(*layers)
def build_gated_equivariant_mlp(
    n_in: int,
    n_out: int,
    n_hidden: Optional[Union[int, Sequence[int]]] = None,
    n_gating_hidden: Optional[Union[int, Sequence[int]]] = None,
    n_layers: int = 2,
    activation: Callable = F.silu,
    sactivation: Callable = F.silu,
) -> nn.Module:
    """
    Build neural network analog to MLP with `GatedEquivariantBlock`s instead
    of dense layers.

    Args:
        n_in: number of input nodes.
        n_out: number of output nodes.
        n_hidden: number of hidden layer nodes.
            If an integer, the same number of nodes is used for all hidden
            layers, resulting in a rectangular network.
            If a sequence, it gives the size of every hidden layer; the depth
            of the network is then taken from the sequence and ``n_layers``
            is not used.
            If None, the number of neurons is halved after each layer
            (starting from ``n_in`` and never dropping below ``n_out``),
            resulting in a pyramidal network.
        n_gating_hidden: value passed as ``n_hidden`` to each
            `GatedEquivariantBlock`.
            If None, the number of input nodes of the respective block is used.
            If an integer, the same value is used for all blocks.
            If a sequence, entry ``i`` is used for the ``i``-th hidden block
            and the last entry for the output block.
        n_layers: number of blocks (hidden blocks plus the output block).
            Only used when ``n_hidden`` is None or an integer.
        activation: Activation function for gating function.
        sactivation: Activation function for scalar outputs. All hidden blocks
            use the same scalar activation; the output block applies none.

    Returns:
        ``nn.Sequential`` of ``snn.GatedEquivariantBlock`` modules.
    """
    # get list of number of nodes in input, hidden & output layers
    if n_hidden is None:
        n_neurons = []
        c_neurons = n_in
        for _ in range(n_layers):
            n_neurons.append(c_neurons)
            c_neurons = max(n_out, c_neurons // 2)
        n_neurons.append(n_out)
    else:
        # get list of number of nodes hidden layers
        if isinstance(n_hidden, int):
            n_hidden = [n_hidden] * (n_layers - 1)
        else:
            n_hidden = list(n_hidden)
        n_neurons = [n_in] + n_hidden + [n_out]

    # Number of blocks actually built. Derived from n_neurons (not n_layers)
    # so that an explicit n_hidden sequence of any length stays consistent.
    n_blocks = len(n_neurons) - 1

    if n_gating_hidden is None:
        n_gating_hidden = n_neurons[:-1]
    elif isinstance(n_gating_hidden, int):
        n_gating_hidden = [n_gating_hidden] * n_blocks
    else:
        n_gating_hidden = list(n_gating_hidden)

    # assign a GatedEquivariantBlock (with activation function) to each hidden layer
    layers = [
        snn.GatedEquivariantBlock(
            n_sin=n_neurons[i],
            n_vin=n_neurons[i],
            n_sout=n_neurons[i + 1],
            n_vout=n_neurons[i + 1],
            n_hidden=n_gating_hidden[i],
            activation=activation,
            sactivation=sactivation,
        )
        for i in range(n_blocks - 1)
    ]

    # assign a GatedEquivariantBlock (without scalar activation function)
    # to the output layer
    layers.append(
        snn.GatedEquivariantBlock(
            n_sin=n_neurons[-2],
            n_vin=n_neurons[-2],
            n_sout=n_neurons[-1],
            n_vout=n_neurons[-1],
            n_hidden=n_gating_hidden[-1],
            activation=activation,
            sactivation=None,
        )
    )

    # put all layers together to make the network
    return nn.Sequential(*layers)
class Residual(nn.Module):
    """
    Pre-activation residual block inspired by He, Kaiming, et al. "Identity
    mappings in deep residual networks."

    Computes ``x + linear2(activation(linear1(activation(x))))``.
    """

    def __init__(
        self,
        num_features: int,
        activation: Optional[Union[Callable, nn.Module]] = None,
        bias: bool = True,
        zero_init: bool = True,
    ) -> None:
        """
        Args:
            num_features: Dimensions of feature space.
            activation: activation function applied before each linear layer.
                The same object is used at both positions. If None, no
                activation is applied.
            bias: whether the two linear layers have a bias term.
            zero_init: if True, the weight of the second linear layer is
                initialized with zeros, otherwise orthogonally.
        """
        super().__init__()
        # initialize attributes
        self.activation1 = activation
        self.linear1 = nn.Linear(num_features, num_features, bias=bias)
        self.activation2 = activation
        self.linear2 = nn.Linear(num_features, num_features, bias=bias)
        self.reset_parameters(bias, zero_init)

    def reset_parameters(self, bias: bool = True, zero_init: bool = True) -> None:
        """
        Re-initialize the parameters of both linear layers.

        With ``zero_init=True`` and zeroed biases the second linear layer
        outputs zeros, so the block computes an identity mapping.

        Args:
            bias: if True, existing bias vectors are set to zero; if False,
                biases are left untouched.
            zero_init: if True, the weight of the second linear layer is set
                to zero, otherwise it is initialized orthogonally.
        """
        nn.init.orthogonal_(self.linear1.weight)
        if zero_init:
            nn.init.zeros_(self.linear2.weight)
        else:
            nn.init.orthogonal_(self.linear2.weight)
        if bias:
            # Layers created with bias=False have ``bias is None``; skip them
            # instead of failing when this method is called with defaults.
            for linear in (self.linear1, self.linear2):
                if linear.bias is not None:
                    nn.init.zeros_(linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply residual block to input atomic features.
        N: Number of atoms.
        num_features: Dimensions of feature space.

        Args:
            x (FloatTensor [N, num_features]): Input feature representations of atoms.

        Returns:
            y (FloatTensor [N, num_features]): Output feature representations of atoms.
        """
        # ``activation`` defaults to None in __init__, so each activation
        # is skipped when it is not set.
        y = x if self.activation1 is None else self.activation1(x)
        y = self.linear1(y)
        if self.activation2 is not None:
            y = self.activation2(y)
        y = self.linear2(y)
        return x + y
class ResidualStack(nn.Module):
    """
    Chain of ``num_residual`` pre-activation residual blocks applied one
    after another.
    """

    def __init__(
        self,
        num_features: int,
        num_residual: int,
        activation: Union[Callable, nn.Module],
        bias: bool = True,
        zero_init: bool = True,
    ) -> None:
        """
        Args:
            num_features: Dimensions of feature space.
            num_residual: Number of residual blocks to be stacked in sequence.
            activation: activation function handed to every residual block.
            bias: forwarded to every residual block.
            zero_init: forwarded to every residual block.
        """
        super().__init__()
        blocks = [
            Residual(num_features, activation, bias, zero_init)
            for _ in range(num_residual)
        ]
        self.stack = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Pass the input features through every residual block in order.
        N: Number of inputs.
        num_features: Dimensions of feature space.

        Args:
            x (FloatTensor [N, num_features]): Input feature representations.

        Returns:
            y (FloatTensor [N, num_features]): Output feature representations.
        """
        features = x
        for block in self.stack:
            features = block(features)
        return features
class ResidualMLP(nn.Module):
    """
    Residual MLP with num_residual residual blocks.

    Computes ``linear(activation(residual(x)))`` where ``residual`` is a
    ``ResidualStack``.
    """

    def __init__(
        self,
        num_features: int,
        num_residual: int,
        activation: Union[Callable, nn.Module],
        bias: bool = True,
        zero_init: bool = False,
    ):
        """
        Args:
            num_features: Dimensions of feature space.
            num_residual: Number of residual blocks to be stacked in sequence.
            activation: activation function, used inside the residual blocks
                and before the final linear layer.
            bias: whether the residual blocks and the final linear layer have
                a bias term.
            zero_init: if True, the weight of the final linear layer is
                initialized with zeros, otherwise orthogonally. The residual
                stack is always constructed with ``zero_init=True``.
        """
        super().__init__()
        # NOTE: the residual stack is built with zero_init=True regardless of
        # the ``zero_init`` argument, which only affects ``self.linear``.
        self.residual = ResidualStack(
            num_features, num_residual, activation=activation, bias=bias, zero_init=True
        )
        self.linear = nn.Linear(num_features, num_features, bias=bias)
        self.activation = activation
        self.reset_parameters(bias, zero_init)

    def reset_parameters(self, bias: bool = True, zero_init: bool = False) -> None:
        """
        Re-initialize the final linear layer.

        Args:
            bias: if True, an existing bias vector is set to zero; if False,
                the bias is left untouched.
            zero_init: if True, the weight is set to zero, otherwise it is
                initialized orthogonally.
        """
        if zero_init:
            nn.init.zeros_(self.linear.weight)
        else:
            nn.init.orthogonal_(self.linear.weight)
        # A layer created with bias=False has ``bias is None``; skip it
        # instead of failing when this method is called with defaults.
        if bias and self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the residual stack, then the activation and the final linear layer.

        Args:
            x: input features whose last dimension is ``num_features``.

        Returns:
            Output features with the same shape as ``x``.
        """
        return self.linear(self.activation(self.residual(x)))