import math
from collections.abc import Callable
from typing import Optional
import torch
import torch.nn as nn
class Lambda(nn.Module):
"""
Overview:
A custom lambda module for constructing custom layers.
Interfaces:
``__init__``, ``forward``.
"""
def __init__(self, f: Callable):
"""
Overview:
Initialize the lambda module with a given function.
Arguments:
            - f (:obj:`Callable`): A Python function to be applied to the input tensor in ``forward``.
"""
super(Lambda, self).__init__()
self.f = f
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
            Apply the stored function ``f`` to the input tensor.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - output (:obj:`torch.Tensor`): The result of ``self.f(x)``.
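        Examples:
            >>> # Illustrative usage of the Lambda wrapper:
            >>> square = Lambda(lambda x: x ** 2)
            >>> y = square(torch.tensor([1., 2., 3.]))  # tensor([1., 4., 9.])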
"""
return self.f(x)
class GLU(nn.Module):
"""
Overview:
        Gated Linear Unit (GLU), a gating-based activation mechanism first proposed in
        [Language Modeling with Gated Convolutional Networks](https://arxiv.org/pdf/1612.08083.pdf).
Interfaces:
``__init__``, ``forward``.
"""
def __init__(self, input_dim: int, output_dim: int, context_dim: int, input_type: str = 'fc') -> None:
"""
Overview:
Initialize the GLU module.
Arguments:
- input_dim (:obj:`int`): The dimension of the input tensor.
- output_dim (:obj:`int`): The dimension of the output tensor.
- context_dim (:obj:`int`): The dimension of the context tensor.
            - input_type (:obj:`str`): The type of the input layer, currently supports ['fc', 'conv2d'].
"""
super(GLU, self).__init__()
assert (input_type in ['fc', 'conv2d'])
if input_type == 'fc':
self.layer1 = nn.Linear(context_dim, input_dim)
self.layer2 = nn.Linear(input_dim, output_dim)
elif input_type == 'conv2d':
self.layer1 = nn.Conv2d(context_dim, input_dim, 1, 1, 0)
self.layer2 = nn.Conv2d(input_dim, output_dim, 1, 1, 0)
def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
"""
Overview:
Compute the GLU transformation of the input tensor.
Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
- context (:obj:`torch.Tensor`): The context tensor.
Returns:
- x (:obj:`torch.Tensor`): The output tensor after GLU transformation.
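        Examples:
            >>> # Illustrative usage with assumed shapes (batch size 4):
            >>> glu = GLU(input_dim=32, output_dim=64, context_dim=128, input_type='fc')
            >>> x = torch.randn(4, 32)
            >>> context = torch.randn(4, 128)
            >>> y = glu(x, context)  # shape: (4, 64)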
"""
gate = self.layer1(context)
gate = torch.sigmoid(gate)
x = gate * x
x = self.layer2(x)
return x
class Swish(nn.Module):
"""
Overview:
        Swish activation function, a smooth, non-monotonic function. For more details, please refer
        to [Searching for Activation Functions](https://arxiv.org/pdf/1710.05941.pdf).
Interfaces:
``__init__``, ``forward``.
"""
def __init__(self):
"""
Overview:
Initialize the Swish module.
"""
super(Swish, self).__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
Compute the Swish transformation of the input tensor.
Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
Returns:
- x (:obj:`torch.Tensor`): The output tensor after Swish transformation.
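        Examples:
            >>> # Illustrative usage; the output keeps the input shape:
            >>> swish = Swish()
            >>> y = swish(torch.randn(4, 32))  # x * sigmoid(x), shape (4, 32)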
"""
return x * torch.sigmoid(x)
class GELU(nn.Module):
"""
Overview:
        Gaussian Error Linear Units (GELU) activation function, which is widely used in NLP models such as GPT
        and BERT. For more details, please refer to the original paper: https://arxiv.org/pdf/1606.08415.pdf.
Interfaces:
``__init__``, ``forward``.
"""
def __init__(self):
"""
Overview:
Initialize the GELU module.
"""
super(GELU, self).__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
            Compute the GELU transformation of the input tensor (this implementation uses the tanh approximation).
Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
Returns:
- x (:obj:`torch.Tensor`): The output tensor after GELU transformation.
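        Examples:
            >>> # Illustrative usage; the output keeps the input shape:
            >>> gelu = GELU()
            >>> y = gelu(torch.randn(4, 32))  # shape (4, 32)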
"""
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
def build_activation(activation: str, inplace: Optional[bool] = None) -> nn.Module:
"""
Overview:
Build and return the activation module according to the given type.
Arguments:
- activation (:obj:`str`): The type of activation module, now supports \
['relu', 'glu', 'prelu', 'swish', 'gelu', 'tanh', 'sigmoid', 'softplus', 'elu', 'square', 'identity'].
        - inplace (:obj:`Optional[bool]`): Whether to execute the operation in-place (only supported by 'relu'), \
            defaults to None.
Returns:
        - act_func (:obj:`nn.Module`): The corresponding activation module. For 'glu', the ``GLU`` class itself is \
            returned rather than an instance, because its constructor requires extra arguments.
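    Examples:
        >>> # Illustrative usage; 'relu' is the only type that accepts the inplace flag:
        >>> act = build_activation('relu', inplace=True)
        >>> y = act(torch.randn(4, 32))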
"""
if inplace is not None:
assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation)
else:
inplace = False
act_func = {
'relu': nn.ReLU(inplace=inplace),
'glu': GLU,
'prelu': nn.PReLU(),
'swish': Swish(),
'gelu': GELU(),
"tanh": nn.Tanh(),
"sigmoid": nn.Sigmoid(),
"softplus": nn.Softplus(),
"elu": nn.ELU(),
"square": Lambda(lambda x: x ** 2),
"identity": Lambda(lambda x: x),
}
    # Look up with the lowercased key so that capitalized names such as 'ReLU' are also accepted.
    if activation.lower() in act_func.keys():
        return act_func[activation.lower()]
    else:
        raise KeyError("invalid key for activation: {}".format(activation))