# Written by Shigeki Karita, 2019
# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Adapted by Florian Lux, 2021

import torch


class LayerNorm(torch.nn.LayerNorm):
    """
    Layer normalization module that can normalize over a chosen dimension.

    ``torch.nn.LayerNorm`` normalizes over the trailing dimension; this
    subclass lets the caller select a single dimension via ``dim``.

    Args:
        nout (int): Size of the dimension being normalized.
        dim (int): Dimension to be normalized (negative indices allowed).
        eps (float): Value added to the variance for numerical stability.
    """

    def __init__(self, nout, dim=-1, eps=1e-12):
        """
        Construct a LayerNorm object.
        """
        super().__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """
        Apply layer normalization along ``self.dim``.

        Args:
            x (torch.Tensor): Input tensor; its size along ``self.dim`` must
                equal ``nout``.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.
        """
        if self.dim == -1:
            return super().forward(x)
        # torch.nn.LayerNorm only normalizes the last axis, so swap the
        # requested axis (not a hard-coded axis 1) to the end, normalize,
        # and swap it back.
        return super().forward(x.transpose(self.dim, -1)).transpose(self.dim, -1)