#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Layer normalization module."""
import torch
import torch.nn as nn


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Args:
        nout (int): Output dim size.
        dim (int): Dimension to be normalized.
    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return (
            super(LayerNorm, self)
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )
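

# Illustrative usage sketch (not part of the original module): with dim=1,
# LayerNorm normalizes the channel axis of a channel-first tensor by moving
# that axis to the last position, applying nn.LayerNorm, and moving it back.
# The shapes below are arbitrary example values.
def _layer_norm_example():
    x = torch.randn(2, 8, 100)    # (batch, channel, time)
    norm = LayerNorm(8, dim=1)    # normalize over the channel axis (size 8)
    return norm(x)                # output keeps the input shape (2, 8, 100)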


class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    Arguments
    ---------
    dim : int
        Size of the normalized (channel) dimension.
    shape : int
        Expected number of input dimensions: 3 for [N, C, L] or
        4 for [N, C, K, S].
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        If True, this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            # Affine parameters broadcast over the non-channel dimensions.
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            if shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # x = N x C x K x S or N x C x L
        # gln: mean, var are taken over all non-batch dims -> N x 1 x 1 (x 1)
        if x.dim() == 3:
            mean = torch.mean(x, (1, 2), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
            if self.elementwise_affine:
                x = self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)

        if x.dim() == 4:
            mean = torch.mean(x, (1, 2, 3), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
            if self.elementwise_affine:
                x = self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        return x
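

# Illustrative usage sketch (not part of the original module): the docstring
# example covers 3-D input; for a 4-D [N, C, K, S] tensor, pass shape=4 so the
# affine parameters broadcast as (C, 1, 1). The shapes below are arbitrary.
def _global_layer_norm_4d_example():
    x = torch.randn(5, 10, 20, 30)    # [N, C, K, S]
    gln = GlobalLayerNorm(10, 4)      # dim = number of channels C, shape = 4
    return gln(x)                     # normalized over C, K, S for each sample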


class CumulativeLayerNorm(nn.LayerNorm):
    """Calculate Cumulative Layer Normalization.

    Arguments
    ---------
    dim : int
        Dimension that you want to normalize.
    elementwise_affine : bool
        If True, this module has learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super(CumulativeLayerNorm, self).__init__(
            dim, elementwise_affine=elementwise_affine, eps=1e-8
        )

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # x: N x C x K x S or N x C x L
        if x.dim() == 4:
            x = x.permute(0, 2, 3, 1).contiguous()
            # N x K x S x C == only channel norm
            x = super().forward(x)
            # N x C x K x S
            x = x.permute(0, 3, 1, 2).contiguous()

        if x.dim() == 3:
            x = torch.transpose(x, 1, 2)
            # N x L x C == only channel norm
            x = super().forward(x)
            # N x C x L
            x = torch.transpose(x, 1, 2)
        return x
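

# Illustrative usage sketch (not part of the original module): for a 4-D
# [N, C, K, S] tensor, CumulativeLayerNorm moves the channel axis to the last
# position, applies nn.LayerNorm over C, and moves it back. Arbitrary shapes.
def _cumulative_layer_norm_4d_example():
    x = torch.randn(5, 10, 20, 30)    # [N, C, K, S]
    cln = CumulativeLayerNorm(10)     # dim = number of channels C
    return cln(x)                     # output keeps shape [N, C, K, S]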


class ScaleNorm(nn.Module):
    """Scale normalization: rescale the input by a single learnable gain
    divided by the L2 norm of the last dimension (scaled by dim ** -0.5)."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim ** -0.5                # 1 / sqrt(dim)
        self.eps = eps                          # lower bound on the norm
        self.g = nn.Parameter(torch.ones(1))    # learnable gain

    def forward(self, x):
        # L2 norm over the last dimension, scaled by 1 / sqrt(dim)
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g
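

if __name__ == "__main__":
    # Minimal sanity checks (not part of the original module): every layer
    # defined above should preserve the shape of its input. The shapes are
    # arbitrary example values.
    x3 = torch.randn(5, 10, 20)                    # [N, C, L]
    assert LayerNorm(10, dim=1)(x3).shape == x3.shape
    assert GlobalLayerNorm(10, 3)(x3).shape == x3.shape
    assert CumulativeLayerNorm(10)(x3).shape == x3.shape
    assert ScaleNorm(20)(x3).shape == x3.shape     # normalizes the last dim (20)
    print("All normalization layers preserve the input shape.")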