from typing import Callable, Union, Optional
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import xavier_uniform_
from torch.nn.init import zeros_
__all__ = ["Dense"]


class Dense(nn.Linear):
    r"""Fully connected linear layer with an optional activation function.

    .. math::
       y = \mathrm{activation}(x W^T + b)

    Unlike plain ``nn.Linear``, the weight and bias are initialized with
    user-supplied initializer callables (Xavier-uniform weights and zero
    bias by default).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        activation: Optional[Union[Callable, nn.Module]] = None,
        weight_init: Callable = xavier_uniform_,
        bias_init: Callable = zeros_,
    ):
        """
        Args:
            in_features: number of input features :math:`x`.
            out_features: number of output features :math:`y`.
            bias: if False, the layer has no additive bias :math:`b`.
            activation: callable applied to the linear output; if None,
                ``nn.Identity`` is used, i.e. no activation.
            weight_init: initializer called on ``self.weight`` whenever
                ``reset_parameters`` runs. Its return value is ignored, so it
                must modify the tensor in place.
            bias_init: initializer called on ``self.bias`` whenever
                ``reset_parameters`` runs (skipped when ``bias`` is False).
                Its return value is ignored, so it must modify the tensor in
                place.
        """
        # nn.Linear.__init__ calls self.reset_parameters(), which reads the
        # initializers, so they must be stored before calling super().
        self.weight_init = weight_init
        self.bias_init = bias_init
        super().__init__(in_features, out_features, bias)
        self.activation = nn.Identity() if activation is None else activation

    def reset_parameters(self) -> None:
        """(Re-)initialize weight and bias with the configured initializers."""
        self.weight_init(self.weight)
        if self.bias is not None:
            self.bias_init(self.bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute ``activation(input @ weight.T + bias)``.

        Args:
            input: tensor whose last dimension has size ``in_features``.

        Returns:
            Tensor of the same leading shape whose last dimension has size
            ``out_features``.
        """
        y = F.linear(input, self.weight, self.bias)
        return self.activation(y)