Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Layer normalization module."""
import torch | |
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Subclass of ``torch.nn.LayerNorm`` (constructed with ``eps=1e-12``) that
    can normalize an axis other than the last one: the chosen axis is swapped
    into the last position, normalized by the parent class, and swapped back.

    :param int nout: output dim size (size of the axis being normalized)
    :param int dim: dimension to be normalized

    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        :param torch.Tensor x: input tensor
        :return: layer normalized tensor (same shape as ``x``)
        :rtype: torch.Tensor

        """
        if self.dim == -1:
            # Parent class already normalizes over the last axis.
            return super(LayerNorm, self).forward(x)
        # Move the configured axis (not a hard-coded axis 1) to the end so the
        # parent can normalize it, then restore the original layout.
        return (
            super(LayerNorm, self)
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )