from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import librosa
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask


class FeatureTransform(torch.nn.Module):
    def __init__(
        self,
        # Mel options
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: Optional[float] = None,
        # Normalization options
        stats_file: Optional[str] = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        if self.apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # x: (B, T, F) single-channel or (B, T, C, F) multi-channel STFT
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # Multi-channel input: reduce (B, T, C, F) to (B, T, F)
            if self.training:
                # Select a channel at random during training
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel at inference time
                h = x[:, :, 0, :]
        else:
            h = x

        # Power spectrum: |STFT|^2, ComplexTensor -> torch.Tensor
        h = h.real ** 2 + h.imag ** 2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
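

# A minimal usage sketch (illustrative; not part of the original module).
# It feeds a random multi-channel complex STFT, wrapped in ComplexTensor,
# through FeatureTransform and checks the output shape. No stats file is
# given, so GlobalMVN is skipped and only log-mel plus utterance MVN run.
def _example_feature_transform():
    transform = FeatureTransform(fs=16000, n_fft=512, n_mels=80)
    n_freq = 512 // 2 + 1
    # (B, T, C, F): 2 utterances, 100 frames, 4 channels
    stft = ComplexTensor(
        torch.randn(2, 100, 4, n_freq), torch.randn(2, 100, 4, n_freq)
    )
    feats, ilens = transform(stft, torch.LongTensor([100, 80]))
    assert feats.shape == (2, 100, 80)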


class LogMel(torch.nn.Module):
    """Convert STFT to fbank feats

    The arguments are the same as librosa.filters.mel

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            if 1, divide the triangular mel weights by the width of the mel band
            (area normalization). Otherwise, leave all the triangles aiming for
            a peak value of 1.0
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: Optional[float] = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()

        _mel_options = dict(
            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
        )
        self.mel_options = _mel_options

        # melmat: (n_mels, n_freqs); store the transpose (n_freqs, n_mels)
        # so it can be right-multiplied onto (B, T, n_freqs) features
        melmat = librosa.filters.mel(**_mel_options)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, n_freqs) power spectrum -> mel_feat: (B, T, n_mels)
        mel_feat = torch.matmul(feat, self.melmat)
        # Floor with a small constant to avoid log(0)
        logmel_feat = (mel_feat + 1e-20).log()
        # Zero out the padded frames
        logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
        return logmel_feat, ilens
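

# A minimal sketch of LogMel on its own (illustrative; not part of the
# original module): it feeds a fake non-negative power spectrum of shape
# (B, T, n_fft // 2 + 1) through the filterbank and checks the output shape.
def _example_logmel():
    logmel = LogMel(fs=16000, n_fft=512, n_mels=80)
    power = torch.rand(2, 50, 512 // 2 + 1)  # stand-in for |STFT|^2
    feats, _ = logmel(power, torch.LongTensor([50, 30]))
    assert feats.shape == (2, 50, 80)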


class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization

    Args:
        stats_file(str): npy file of a 1-dim stats array.
            The first (len(array) - 1) / 2 elements are the sum of features,
            the following elements excluding the last one are the sum of the
            squared features, and the last element equals the number of samples.
        norm_means(bool): apply mean normalization
        norm_vars(bool): apply variance normalization
        eps(float): floor value for the standard deviation
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars

        self.stats_file = stats_file
        stats = np.load(stats_file)
        stats = stats.astype(float)
        assert (len(stats) - 1) % 2 == 0, stats.shape

        # stats layout: [sum(x), sum(x^2), count]; recover mean and std
        count = stats.flatten()[-1]
        mean = stats[: (len(stats) - 1) // 2] / count
        var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
        std = np.maximum(np.sqrt(var), eps)

        # Normalization is applied as (x + bias) * scale
        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # x: (B, T, D), modified in-place
        if self.norm_means:
            x += self.bias.type_as(x)
            # Re-zero the padded frames after adding the bias (in-place)
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)

        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens
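

# A minimal sketch of GlobalMVN (illustrative; not part of the original
# module): it fabricates a stats array in the layout the docstring describes,
# [sum(x), sum(x^2), count], saves it to a temporary .npy file, and
# normalizes a random batch with it.
def _example_global_mvn():
    import tempfile

    d, count = 80, 1000
    feats = np.random.randn(count, d)
    stats = np.concatenate([feats.sum(axis=0), (feats ** 2).sum(axis=0), [count]])
    with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as f:
        np.save(f, stats)
        path = f.name
    mvn = GlobalMVN(path)
    x = torch.randn(2, 50, d)
    y, _ = mvn(x, torch.LongTensor([50, 40]))
    assert y.shape == (2, 50, d)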


class UtteranceMVN(torch.nn.Module):
    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps

    def extra_repr(self):
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        return utterance_mvn(
            x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
        )


def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means:
        norm_vars:
        eps:
    """
    ilens_ = ilens.type_as(x)
    # Padded frames are zero, so summing over T and dividing by the true
    # lengths gives the per-utterance mean: (B, D)
    mean = x.sum(dim=1) / ilens_[:, None]

    if norm_means:
        x -= mean[:, None, :]
        x_ = x
    else:
        x_ = x - mean[:, None, :]

    # Re-zero the padded frames (in-place) so they don't bias the variance
    x_.masked_fill_(make_pad_mask(ilens, x_, 1), 0.0)
    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        x /= var.sqrt()[:, None, :]
        x_ = x
    return x_, ilens
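

# A minimal sketch of utterance_mvn (illustrative; not part of the original
# module): on a zero-padded batch, each utterance becomes zero-mean over its
# valid frames, and the padded frames stay zero.
def _example_utterance_mvn():
    x = torch.randn(2, 50, 80)
    ilens = torch.LongTensor([50, 30])
    x[1, 30:] = 0.0  # zero padding, as the function assumes
    y, _ = utterance_mvn(x, ilens, norm_means=True, norm_vars=False)
    assert torch.allclose(y[1, :30].mean(dim=0), torch.zeros(80), atol=1e-5)
    assert (y[1, 30:] == 0.0).all()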


def feature_transform_for(args, n_fft):
    return FeatureTransform(
        # Mel options
        fs=args.fbank_fs,
        n_fft=n_fft,
        n_mels=args.n_mels,
        fmin=args.fbank_fmin,
        fmax=args.fbank_fmax,
        # Normalization options
        stats_file=args.stats_file,
        apply_uttmvn=args.apply_uttmvn,
        uttmvn_norm_means=args.uttmvn_norm_means,
        uttmvn_norm_vars=args.uttmvn_norm_vars,
    )
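

# A minimal usage sketch for feature_transform_for (illustrative; not part of
# the original module). The function only reads attributes off ``args``, so
# any namespace carrying the expected fields works; the values below are
# made-up defaults mirroring the attribute accesses above.
def _example_feature_transform_for():
    from types import SimpleNamespace

    args = SimpleNamespace(
        fbank_fs=16000,
        n_mels=80,
        fbank_fmin=0.0,
        fbank_fmax=None,
        stats_file=None,
        apply_uttmvn=True,
        uttmvn_norm_means=True,
        uttmvn_norm_vars=False,
    )
    return feature_transform_for(args, n_fft=512)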