import torch


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Args:
        nout (int): Output dim size.
        dim (int): Dimension to be normalized.

    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.

        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        # Move the target dimension to the last axis, normalize, and move it back,
        # so `dim` other than -1 is honored as documented.
        return (
            super(LayerNorm, self)
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )
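
# A minimal usage sketch, not part of the module itself: the tensor shapes and the
# feature size 256 below are illustrative assumptions, not taken from the original code.
if __name__ == "__main__":
    # Default (dim=-1): normalize the last axis of a (batch, time, feature) tensor.
    x = torch.randn(8, 50, 256)
    norm = LayerNorm(256)
    y = norm(x)  # shape (8, 50, 256), normalized over the 256-dim feature axis

    # Channel-first input of shape (batch, channels, time): pass dim=1 to normalize
    # along the channel axis; the forward pass transposes, normalizes, and restores
    # the original layout.
    x_cf = torch.randn(8, 256, 50)
    norm_cf = LayerNorm(256, dim=1)
    y_cf = norm_cf(x_cf)  # shape (8, 256, 50), normalized over the channel axis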