TheComputerMan commited on
Commit
d927b86
1 Parent(s): 9adcb78

Upload LayerNorm.py

Browse files
Files changed (1) hide show
  1. LayerNorm.py +36 -0
LayerNorm.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Written by Shigeki Karita, 2019
2
+ # Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux, 2021
4
+
5
+ import torch
6
+
7
+
8
+ class LayerNorm(torch.nn.LayerNorm):
9
+ """
10
+ Layer normalization module.
11
+
12
+ Args:
13
+ nout (int): Output dim size.
14
+ dim (int): Dimension to be normalized.
15
+ """
16
+
17
+ def __init__(self, nout, dim=-1):
18
+ """
19
+ Construct an LayerNorm object.
20
+ """
21
+ super(LayerNorm, self).__init__(nout, eps=1e-12)
22
+ self.dim = dim
23
+
24
+ def forward(self, x):
25
+ """
26
+ Apply layer normalization.
27
+
28
+ Args:
29
+ x (torch.Tensor): Input tensor.
30
+
31
+ Returns:
32
+ torch.Tensor: Normalized tensor.
33
+ """
34
+ if self.dim == -1:
35
+ return super(LayerNorm, self).forward(x)
36
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)