from pathlib import Path
from typing import Tuple
from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.layers.inversible_interface import InversibleInterface


class GlobalMVN(AbsNormalize, InversibleInterface):
    """Apply global mean and variance normalization.

    TODO(kamo): Make this class portable somehow

    Args:
        stats_file: npy or npz file holding the feature statistics
        norm_means: Apply mean normalization
        norm_vars: Apply var normalization
        eps: Floor value used to keep the variance strictly positive
    """

    def __init__(
        self,
        stats_file: Union[Path, str],
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        assert check_argument_types()
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps
        stats_file = Path(stats_file)

        self.stats_file = stats_file
        stats = np.load(stats_file)
        if isinstance(stats, np.ndarray):
            # Kaldi-style stats: a 2 x (D + 1) array whose first row holds
            # the per-dimension sums with the frame count in the last column
            count = stats[0].flatten()[-1]
            mean = stats[0, :-1] / count
            var = stats[1, :-1] / count - mean * mean
        else:
            # npz-style stats with explicit count/sum/sum_square arrays
            count = stats["count"]
            sum_v = stats["sum"]
            sum_square_v = stats["sum_square"]
            mean = sum_v / count
            var = sum_square_v / count - mean * mean
        # Floor the variance with eps so the std is never zero
        std = np.sqrt(np.maximum(var, eps))

        self.register_buffer("mean", torch.from_numpy(mean))
        self.register_buffer("std", torch.from_numpy(std))
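
    # A minimal sketch, not part of the ESPnet API, of writing a stats file
    # in the npz layout consumed by __init__ above; "n_frames", "feat_sum",
    # and "feat_sq_sum" are hypothetical per-dimension accumulators over all
    # training frames:
    #
    #     np.savez(
    #         "stats.npz",
    #         count=n_frames,
    #         sum=feat_sum,
    #         sum_square=feat_sq_sum,
    #     )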
    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function

        Args:
            x: (B, L, ...)
            ilens: (B,)
        """
        if ilens is None:
            ilens = x.new_full([x.size(0)], x.size(1))
        norm_means = self.norm_means
        norm_vars = self.norm_vars
        self.mean = self.mean.to(x.device, x.dtype)
        self.std = self.std.to(x.device, x.dtype)
        mask = make_pad_mask(ilens, x, 1)

        if norm_means:
            # Out-of-place when grad is required, in-place otherwise
            if x.requires_grad:
                x = x - self.mean
            else:
                x -= self.mean
        # Zero out the padded region
        if x.requires_grad:
            x = x.masked_fill(mask, 0.0)
        else:
            x.masked_fill_(mask, 0.0)

        if norm_vars:
            x /= self.std

        return x, ilens
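
    # Usage sketch with hypothetical names: for a padded batch "feats" of
    # shape (B, L, D) with valid lengths "feat_lengths" of shape (B,),
    #
    #     normalized, lengths = global_mvn(feats, feat_lengths)
    #
    # returns globally normalized features with every padded frame set to 0.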
    def inverse(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the normalization applied by forward

        Args:
            x: (B, L, ...)
            ilens: (B,)
        """
        if ilens is None:
            ilens = x.new_full([x.size(0)], x.size(1))
        norm_means = self.norm_means
        norm_vars = self.norm_vars
        self.mean = self.mean.to(x.device, x.dtype)
        self.std = self.std.to(x.device, x.dtype)
        mask = make_pad_mask(ilens, x, 1)

        # Zero out the padded region before rescaling
        if x.requires_grad:
            x = x.masked_fill(mask, 0.0)
        else:
            x.masked_fill_(mask, 0.0)

        if norm_vars:
            x *= self.std

        if norm_means:
            x += self.mean
            # Re-mask: adding the mean also shifted the padded frames
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
        return x, ilens
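

if __name__ == "__main__":
    # Self-contained demo, not part of the ESPnet API: write a tiny stats
    # file in the npz layout read by __init__, then check that inverse()
    # undoes forward() on the non-padded frames. The file and variable names
    # below are illustrative only.
    import tempfile

    rng = np.random.default_rng(0)
    frames = rng.standard_normal((1000, 8))  # 1000 frames, 8 feature dims
    with tempfile.TemporaryDirectory() as d:
        stats_path = Path(d) / "demo_stats.npz"
        np.savez(
            stats_path,
            count=frames.shape[0],
            sum=frames.sum(axis=0),
            sum_square=(frames**2).sum(axis=0),
        )
        mvn = GlobalMVN(stats_path, norm_means=True, norm_vars=True)

        feats = torch.randn(2, 5, 8)  # (B, L, D) padded batch
        feat_lengths = torch.tensor([5, 3])
        normalized, _ = mvn(feats.clone(), feat_lengths)
        restored, _ = mvn.inverse(normalized.clone(), feat_lengths)
        # Valid frames round-trip up to float error; padded frames are zeroed
        assert torch.allclose(restored[0], feats[0], atol=1e-4)
        assert torch.allclose(restored[1, :3], feats[1, :3], atol=1e-4)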