# Copyright 2019 Jian Wu
# License: Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as tf

import librosa.filters as filters

from typing import Optional, Tuple
from distutils.version import LooseVersion

EPSILON = float(np.finfo(np.float32).eps)
# wrap the version string for robust comparison
TORCH_VERSION = LooseVersion(th.__version__)

if TORCH_VERSION >= LooseVersion("1.7"):
    from torch.fft import fft as fft_func


def export_jit(transform: nn.Module) -> nn.Module:
    """
    Export the transform module for inference (keep the exportable layers only)
    """
    export_out = [module for module in transform if module.exportable()]
    return nn.Sequential(*export_out)
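

# Usage sketch (added for illustration, not part of the original module): export_jit()
# keeps only the sub-modules whose exportable() returns True. The dummy classes below
# are hypothetical stand-ins for the feature transform layers of this package, which
# define exportable(); the returned nn.Sequential can then be scripted, e.g. with
# th.jit.script(...).
def _example_export_jit() -> nn.Module:

    class _ExportableId(nn.Identity):

        def exportable(self) -> bool:
            return True

    class _NonExportableId(nn.Identity):

        def exportable(self) -> bool:
            return False

    transform = nn.Sequential(_ExportableId(), _NonExportableId())
    # only the first (exportable) layer is kept
    return export_jit(transform)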


def init_window(wnd: str, frame_len: int, device: th.device = "cpu") -> th.Tensor:
    """
    Return window coefficients
    Args:
        wnd: window name
        frame_len: length of the frame
        device: device to place the window tensor on
    """

    def sqrthann(frame_len, periodic=True):
        return th.hann_window(frame_len, periodic=periodic) ** 0.5

    if wnd not in ["bartlett", "hann", "hamm", "blackman", "rect", "sqrthann"]:
        raise RuntimeError(f"Unknown window type: {wnd}")
    wnd_tpl = {
        "sqrthann": sqrthann,
        "hann": th.hann_window,
        "hamm": th.hamming_window,
        "blackman": th.blackman_window,
        "bartlett": th.bartlett_window,
        "rect": th.ones,
    }
    if wnd != "rect":
        # periodic windows to match librosa
        c = wnd_tpl[wnd](frame_len, periodic=True)
    else:
        c = wnd_tpl[wnd](frame_len)
    return c.to(device)
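

# Usage sketch (illustrative values, not module defaults): a 400-sample sqrt-Hann
# window, i.e. 25 ms frames at 16 kHz, placed on CPU.
def _example_init_window() -> th.Tensor:
    win = init_window("sqrthann", 400)
    # periodic window of shape (400,); win ** 2 is a Hann window
    return win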


def init_kernel(
    frame_len: int,
    frame_hop: int,
    window: th.Tensor,
    round_pow_of_two: bool = True,
    normalized: bool = False,
    inverse: bool = False,
    mode: str = "librosa",
) -> Tuple[th.Tensor, th.Tensor]:
    """
    Return the STFT kernels
    Args:
        frame_len: length of the frame
        frame_hop: hop size between frames
        window: window tensor
        round_pow_of_two: if True, round the FFT size up to the nearest power of two
        normalized: return normalized DFT matrix
        inverse: return iDFT matrix
        mode: framing mode (librosa or kaldi)
    """
    if mode not in ["librosa", "kaldi"]:
        raise ValueError(f"Unsupported mode: {mode}")
    # FFT size: B
    if round_pow_of_two or mode == "kaldi":
        fft_size = 2 ** math.ceil(math.log2(frame_len))
    else:
        fft_size = frame_len
    # center padding window if needed
    if mode == "librosa" and fft_size != frame_len:
        lpad = (fft_size - frame_len) // 2
        window = tf.pad(window, (lpad, fft_size - frame_len - lpad))
    if normalized:
        # make K^H * K = I
        S = fft_size ** 0.5
    else:
        S = 1
    # W x B x 2
    if TORCH_VERSION >= LooseVersion("1.7"):
        K = fft_func(th.eye(fft_size) / S, dim=-1)
        K = th.stack([K.real, K.imag], dim=-1)
    else:
        I = th.stack([th.eye(fft_size), th.zeros(fft_size, fft_size)], dim=-1)
        K = th.fft(I / S, 1)
    if mode == "kaldi":
        K = K[:frame_len]
    if inverse and not normalized:
        # to make K^H * K = I
        K = K / fft_size
    # 2 x B x W
    K = th.transpose(K, 0, 2)
    # 2B x 1 x W
    K = th.reshape(K, (fft_size * 2, 1, K.shape[-1]))
    return K.to(window.device), window
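

# Shape sketch (assumed sizes): with frame_len=400 and round_pow_of_two=True the FFT
# size B is 512, so the kernel is 2B x 1 x W = 1024 x 1 x 512 and the window returned
# by init_kernel() is zero-padded to 512 ("librosa" framing pads the window, while
# "kaldi" framing truncates the kernel to the frame length instead).
def _example_init_kernel() -> Tuple[th.Tensor, th.Tensor]:
    window = init_window("hann", 400)
    kernel, window = init_kernel(400, 160, window, round_pow_of_two=True, mode="librosa")
    # kernel: 1024 x 1 x 512, window: 512
    return kernel, window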


def mel_filter(
    frame_len: int,
    round_pow_of_two: bool = True,
    num_bins: Optional[int] = None,
    sr: int = 16000,
    num_mels: int = 80,
    fmin: float = 0.0,
    fmax: Optional[float] = None,
    norm: bool = False,
) -> th.Tensor:
    """
    Return mel filter coefficients
    Args:
        frame_len: length of the frame
        round_pow_of_two: if True, round the FFT size up to the nearest power of two
        num_bins: number of the frequency bins produced by STFT
        sr: sample rate of the waveform
        num_mels: number of the mel bands
        fmin: lowest frequency (in Hz)
        fmax: highest frequency (in Hz)
        norm: normalize the mel filter coefficients
    """
    # FFT points
    if num_bins is None:
        N = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
    else:
        N = (num_bins - 1) * 2
    # fmin & fmax
    freq_upper = sr // 2
    if fmax is None:
        fmax = freq_upper
    else:
        fmax = min(fmax + freq_upper if fmax < 0 else fmax, freq_upper)
    fmin = max(0, fmin)
    # mel filter coefficients (keyword arguments for compatibility across librosa versions)
    mel = filters.mel(
        sr=sr,
        n_fft=N,
        n_mels=num_mels,
        fmax=fmax,
        fmin=fmin,
        htk=True,
        norm="slaney" if norm else None,
    )
    # num_mels x (N // 2 + 1)
    return th.tensor(mel, dtype=th.float32)
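

# Usage sketch (assumed 16 kHz setup): an 80-band mel filterbank matching a 400-sample
# frame rounded up to a 512-point FFT, i.e. 257 frequency bins.
def _example_mel_filter() -> th.Tensor:
    mel = mel_filter(400, round_pow_of_two=True, sr=16000, num_mels=80)
    # mel: 80 x 257; for an F x T magnitude spectrogram spec, use th.matmul(mel, spec)
    return mel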


def speed_perturb_filter(
    src_sr: int, dst_sr: int, cutoff_ratio: float = 0.95, num_zeros: int = 64
) -> th.Tensor:
    """
    Return speed perturb filters, reference:
    https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py
    Args:
        src_sr: sample rate of the source signal
        dst_sr: sample rate of the target signal
        cutoff_ratio: cutoff frequency relative to the Nyquist frequency
        num_zeros: controls the filter support (number of zeros of the sinc function)
    Return:
        weight (Tensor): coefficients of the filter
    """
    if src_sr == dst_sr:
        raise ValueError(f"src_sr should not be equal to dst_sr: {src_sr}/{dst_sr}")
    gcd = math.gcd(src_sr, dst_sr)
    src_sr = src_sr // gcd
    dst_sr = dst_sr // gcd
    if src_sr == 1 or dst_sr == 1:
        raise ValueError("integer downsample/upsample ratios are not supported")
    zeros_per_block = min(src_sr, dst_sr) * cutoff_ratio
    padding = 1 + int(num_zeros / zeros_per_block)
    # dst_sr x src_sr x K
    times = (
        np.arange(dst_sr)[:, None, None] / float(dst_sr)
        - np.arange(src_sr)[None, :, None] / float(src_sr)
        - np.arange(2 * padding + 1)[None, None, :]
        + padding
    )
    window = np.heaviside(1 - np.abs(times / padding), 0.0) * (
        0.5 + 0.5 * np.cos(times / padding * math.pi)
    )
    weight = np.sinc(times * zeros_per_block) * window * zeros_per_block / float(src_sr)
    return th.tensor(weight, dtype=th.float32)
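

# Usage sketch (assumed rates): filters for resampling between 16000 Hz and 17600 Hz
# (reduced to a 10:11 ratio after removing the gcd), as used for speed perturbation.
def _example_speed_perturb_filter() -> th.Tensor:
    weight = speed_perturb_filter(16000, 17600)
    # weight: dst_sr x src_sr x K = 11 x 10 x (2 * padding + 1)
    return weight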


def splice_feature(
    feats: th.Tensor, lctx: int = 1, rctx: int = 1, op: str = "cat"
) -> th.Tensor:
    """
    Splice feature
    Args:
        feats (Tensor): N x ... x T x F, original feature
        lctx: left context
        rctx: right context
        op: operator on feature context
    Return:
        splice (Tensor): feature with context padded
    """
    if lctx + rctx == 0:
        return feats
    if op not in ["cat", "stack"]:
        raise ValueError(f"Unknown op for feature splicing: {op}")
    # [N x ... x T x F, ...]
    ctx = []
    T = feats.shape[-2]
    for c in range(-lctx, rctx + 1):
        idx = th.arange(c, c + T, device=feats.device, dtype=th.int64)
        idx = th.clamp(idx, min=0, max=T - 1)
        ctx.append(th.index_select(feats, -2, idx))
    if op == "cat":
        # N x ... x T x FD
        splice = th.cat(ctx, -1)
    else:
        # N x ... x T x F x D
        splice = th.stack(ctx, -1)
    return splice
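

# Usage sketch (assumed shapes): splice each frame with one left and one right context
# frame, either concatenated along the feature axis ("cat") or stacked on a new
# trailing axis ("stack"); boundary frames are clamped (replicated).
def _example_splice_feature() -> Tuple[th.Tensor, th.Tensor]:
    feats = th.randn(4, 100, 80)  # N x T x F
    cat = splice_feature(feats, lctx=1, rctx=1, op="cat")      # 4 x 100 x 240
    stack = splice_feature(feats, lctx=1, rctx=1, op="stack")  # 4 x 100 x 80 x 3
    return cat, stack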


def _forward_stft(
    wav: th.Tensor,
    kernel: th.Tensor,
    window: th.Tensor,
    return_polar: bool = False,
    pre_emphasis: float = 0,
    frame_hop: int = 256,
    onesided: bool = False,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    STFT function implemented by conv1d (not efficient, but we don't care during training)
    Args:
        wav (Tensor): N x (C) x S
        kernel (Tensor): STFT transform kernels, from init_kernel(...)
        return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
        pre_emphasis: factor of preemphasis
        frame_hop: frame hop size in number of samples
        onesided: return only the first half of the FFT bins
        center: if True, pad the signal so that frames are centered (as librosa does)
    Return:
        transform (Tensor): STFT transform results
    """
    wav_dim = wav.dim()
    if wav_dim not in [2, 3]:
        raise RuntimeError(f"STFT expects 2D/3D tensor, but got {wav_dim:d}D")
    # if N x S, reshape N x 1 x S
    # else: reshape NC x 1 x S
    N, S = wav.shape[0], wav.shape[-1]
    wav = wav.view(-1, 1, S)
    # NC x 1 x S+2P
    if center:
        pad = kernel.shape[-1] // 2
        # NOTE: match with librosa
        wav = tf.pad(wav, (pad, pad), mode="reflect")
    # STFT
    kernel = kernel * window
    if pre_emphasis > 0:
        # NC x W x T
        frames = tf.unfold(
            wav[:, None], (1, kernel.shape[-1]), stride=frame_hop, padding=0
        )
        # follow Kaldi's Preemphasize
        frames[:, 1:] = frames[:, 1:] - pre_emphasis * frames[:, :-1]
        frames[:, 0] *= 1 - pre_emphasis
        # 1 x 2B x W, NC x W x T, NC x 2B x T
        packed = th.matmul(kernel[:, 0][None, ...], frames)
    else:
        packed = tf.conv1d(wav, kernel, stride=frame_hop, padding=0)
    # NC x 2B x T => N x C x 2B x T
    if wav_dim == 3:
        packed = packed.view(N, -1, packed.shape[-2], packed.shape[-1])
    # N x (C) x B x T
    real, imag = th.chunk(packed, 2, dim=-2)
    # N x (C) x B/2+1 x T
    if onesided:
        num_bins = kernel.shape[0] // 4 + 1
        real = real[..., :num_bins, :]
        imag = imag[..., :num_bins, :]
    if return_polar:
        mag = (real ** 2 + imag ** 2 + eps) ** 0.5
        pha = th.atan2(imag, real)
        return th.stack([mag, pha], dim=-1)
    else:
        return th.stack([real, imag], dim=-1)
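

# Pipeline sketch (assumed sizes): the conv1d-based analysis chains init_window(),
# init_kernel() and _forward_stft(); with onesided=True the output has
# fft_size // 2 + 1 = 257 bins for a 512-point FFT.
def _example_forward_stft_conv1d() -> th.Tensor:
    wav = th.randn(2, 16000)  # N x S
    kernel, window = init_kernel(400, 160, init_window("hann", 400))
    transform = _forward_stft(wav, kernel, window, frame_hop=160, onesided=True)
    # transform: 2 x 257 x T x 2 (real/imag, since return_polar=False)
    return transform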


def _inverse_stft(
    transform: th.Tensor,
    kernel: th.Tensor,
    window: th.Tensor,
    return_polar: bool = False,
    frame_hop: int = 256,
    onesided: bool = False,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    iSTFT function implemented by conv1d
    Args:
        transform (Tensor): STFT transform results
        kernel (Tensor): STFT transform kernels, from init_kernel(...)
        return_polar (bool): keep consistent with the value used in _forward_stft
        frame_hop: frame hop size in number of samples
        onesided: the transform only keeps the first half of the FFT bins
        center: should match the value used in _forward_stft
    Return:
        wav (Tensor): N x S
    """
    # (N) x F x T x 2
    transform_dim = transform.dim()
    # if F x T x 2, reshape 1 x F x T x 2
    if transform_dim == 3:
        transform = th.unsqueeze(transform, 0)
    elif transform_dim != 4:
        raise RuntimeError(f"Expect 3D/4D tensor, but got {transform_dim}D")
    if return_polar:
        real = transform[..., 0] * th.cos(transform[..., 1])
        imag = transform[..., 0] * th.sin(transform[..., 1])
    else:
        real, imag = transform[..., 0], transform[..., 1]
    if onesided:
        # [num_bins - 2, ..., 1]
        reverse = range(kernel.shape[0] // 4 - 1, 0, -1)
        # extend matrix: N x B x T
        real = th.cat([real, real[:, reverse]], 1)
        imag = th.cat([imag, -imag[:, reverse]], 1)
    # pack: N x 2B x T
    packed = th.cat([real, imag], dim=1)
    # N x 1 x T
    wav = tf.conv_transpose1d(packed, kernel * window, stride=frame_hop, padding=0)
    # normalized audio samples
    # refer: https://github.com/pytorch/audio/blob/2ebbbf511fb1e6c47b59fd32ad7e66023fa0dff1/torchaudio/functional.py#L171
    num_frames = packed.shape[-1]
    win_length = window.shape[0]
    # W x T
    win = th.repeat_interleave(window[..., None] ** 2, num_frames, dim=-1)
    # do overlap-add (OLA) on the squared windows
    # v1)
    I = th.eye(win_length, device=win.device)[:, None]
    denorm = tf.conv_transpose1d(win[None, ...], I, stride=frame_hop, padding=0)
    # v2)
    # num_samples = (num_frames - 1) * frame_hop + win_length
    # denorm = tf.fold(win[None, ...], (num_samples, 1), (win_length, 1),
    #                  stride=frame_hop)[..., 0]
    if center:
        pad = kernel.shape[-1] // 2
        wav = wav[..., pad:-pad]
        denorm = denorm[..., pad:-pad]
    wav = wav / (denorm + eps)
    # N x S
    return wav.squeeze(1)
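

# Round-trip sketch: synthesis mirrors the analysis above but uses the inverse kernel
# from init_kernel(..., inverse=True); with matching parameters the output should be
# close to the input waveform, up to window edge effects.
def _example_inverse_stft_conv1d() -> th.Tensor:
    wav = th.randn(1, 16000)
    fwd_kernel, window = init_kernel(400, 160, init_window("hann", 400))
    inv_kernel, _ = init_kernel(400, 160, init_window("hann", 400), inverse=True)
    spec = _forward_stft(wav, fwd_kernel, window, frame_hop=160, onesided=True)
    return _inverse_stft(spec, inv_kernel, window, frame_hop=160, onesided=True)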


def _pytorch_stft(
    wav: th.Tensor,
    frame_len: int,
    frame_hop: int,
    n_fft: int = 512,
    return_polar: bool = False,
    window: Optional[th.Tensor] = None,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    Wrapper of the PyTorch STFT function
    Args:
        wav (Tensor): source audio signal
        frame_len: length of the frame
        frame_hop: hop size between frames
        n_fft: FFT size
        return_polar: return the results in polar coordinates
        window: window tensor
        center: same definition as the parameter in librosa.stft
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
    Return:
        transform (Tensor): STFT transform results
    """
    if TORCH_VERSION < LooseVersion("1.7"):
        raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")
    wav_dim = wav.dim()
    if wav_dim not in [2, 3]:
        raise RuntimeError(f"STFT expects 2D/3D tensor, but got {wav_dim:d}D")
    # if N x C x S, reshape NC x S
    N = wav.shape[0]
    wav = wav.view(-1, wav.shape[-1])
    # STFT: NC x F x T x 2
    stft = th.stft(
        wav,
        n_fft,
        hop_length=frame_hop,
        win_length=window.shape[-1],
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
        return_complex=False,
    )
    # NC x F x T x 2 => N x C x F x T x 2
    if wav_dim == 3:
        F, T = stft.shape[-3], stft.shape[-2]
        stft = stft.view(N, -1, F, T, 2)
    # N x (C) x F x T x 2
    if not return_polar:
        return stft
    # N x (C) x F x T
    real, imag = stft[..., 0], stft[..., 1]
    mag = (real ** 2 + imag ** 2 + eps) ** 0.5
    pha = th.atan2(imag, real)
    return th.stack([mag, pha], dim=-1)


def _pytorch_istft(
    transform: th.Tensor,
    frame_len: int,
    frame_hop: int,
    window: th.Tensor,
    n_fft: int = 512,
    return_polar: bool = False,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    Wrapper of the PyTorch iSTFT function
    Args:
        transform (Tensor): results of STFT
        frame_len: length of the frame
        frame_hop: hop size between frames
        window: window tensor
        n_fft: FFT size
        return_polar: keep consistent with _pytorch_stft
        center: same definition as the parameter in librosa.stft
        normalized: use normalized DFT kernel
        onesided: the transform is an onesided STFT
    Return:
        wav (Tensor): synthesized audio
    """
    if TORCH_VERSION < LooseVersion("1.7"):
        raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")
    transform_dim = transform.dim()
    # if F x T x 2, reshape 1 x F x T x 2
    if transform_dim == 3:
        transform = th.unsqueeze(transform, 0)
    elif transform_dim != 4:
        raise RuntimeError(f"Expect 3D/4D tensor, but got {transform_dim}D")
    if return_polar:
        real = transform[..., 0] * th.cos(transform[..., 1])
        imag = transform[..., 0] * th.sin(transform[..., 1])
        transform = th.stack([real, imag], -1)
    # stft is a complex tensor of PyTorch
    stft = th.view_as_complex(transform)
    # (N) x S
    wav = th.istft(
        stft,
        n_fft,
        hop_length=frame_hop,
        win_length=window.shape[-1],
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
        return_complex=False,
    )
    return wav
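

# Round-trip sketch for the th.stft/th.istft wrappers (requires PyTorch >= 1.7); the
# frame sizes below are assumed example values, and center=True is used so that
# th.istft can invert the transform without boundary issues.
def _example_pytorch_stft_istft() -> th.Tensor:
    wav = th.randn(2, 16000)
    window = init_window("hann", 400)
    spec = _pytorch_stft(wav, 400, 160, n_fft=512, window=window, center=True)
    return _pytorch_istft(spec, 400, 160, window, n_fft=512, center=True)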


def forward_stft(
    wav: th.Tensor,
    frame_len: int,
    frame_hop: int,
    window: str = "sqrthann",
    round_pow_of_two: bool = True,
    return_polar: bool = False,
    pre_emphasis: float = 0,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    mode: str = "librosa",
    eps: float = EPSILON,
) -> th.Tensor:
    """
    STFT function implementation, equivalent to the STFT layer
    Args:
        wav: source audio signal
        frame_len: length of the frame
        frame_hop: hop size between frames
        return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
        window: window name
        center: center flag (similar to that in librosa.stft)
        round_pow_of_two: if True, round the FFT size up to the nearest power of two
        pre_emphasis: factor of preemphasis
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
        mode: STFT mode, "kaldi", "librosa" or "torch"
    Return:
        transform: results of STFT
    """
    window = init_window(window, frame_len, device=wav.device)
    if mode == "torch":
        n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
        return _pytorch_stft(
            wav,
            frame_len,
            frame_hop,
            n_fft=n_fft,
            return_polar=return_polar,
            window=window,
            normalized=normalized,
            onesided=onesided,
            center=center,
            eps=eps,
        )
    else:
        kernel, window = init_kernel(
            frame_len,
            frame_hop,
            window=window,
            round_pow_of_two=round_pow_of_two,
            normalized=normalized,
            inverse=False,
            mode=mode,
        )
        return _forward_stft(
            wav,
            kernel,
            window,
            return_polar=return_polar,
            frame_hop=frame_hop,
            pre_emphasis=pre_emphasis,
            onesided=onesided,
            center=center,
            eps=eps,
        )
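

# Usage sketch (assumed 16 kHz setup): magnitude/phase STFT of a batch of waveforms
# via the high-level functional API.
def _example_forward_stft() -> th.Tensor:
    wav = th.randn(4, 16000)  # N x S
    transform = forward_stft(wav, 400, 160, window="hann", return_polar=True)
    # transform: 4 x 257 x T x 2, [..., 0] is magnitude and [..., 1] is phase
    return transform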


def inverse_stft(
    transform: th.Tensor,
    frame_len: int,
    frame_hop: int,
    return_polar: bool = False,
    window: str = "sqrthann",
    round_pow_of_two: bool = True,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    mode: str = "librosa",
    eps: float = EPSILON,
) -> th.Tensor:
    """
    iSTFT function implementation, equivalent to the iSTFT layer
    Args:
        transform: results of STFT
        frame_len: length of the frame
        frame_hop: hop size between frames
        return_polar: keep consistent with the forward_stft(...) call
        window: window name
        center: center flag (similar to that in librosa.stft)
        round_pow_of_two: if True, round the FFT size up to the nearest power of two
        normalized: use normalized DFT kernel
        onesided: the transform is an onesided STFT
        mode: STFT mode, "kaldi", "librosa" or "torch"
    Return:
        wav: synthesized signal
    """
    window = init_window(window, frame_len, device=transform.device)
    if mode == "torch":
        n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
        return _pytorch_istft(
            transform,
            frame_len,
            frame_hop,
            n_fft=n_fft,
            return_polar=return_polar,
            window=window,
            normalized=normalized,
            onesided=onesided,
            center=center,
            eps=eps,
        )
    else:
        kernel, window = init_kernel(
            frame_len,
            frame_hop,
            window,
            round_pow_of_two=round_pow_of_two,
            normalized=normalized,
            inverse=True,
            mode=mode,
        )
        return _inverse_stft(
            transform,
            kernel,
            window,
            return_polar=return_polar,
            frame_hop=frame_hop,
            onesided=onesided,
            center=center,
            eps=eps,
        )
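

# Round-trip sketch: inverse_stft() undoes forward_stft() when the framing parameters
# and mode match; with center=False the very first/last samples are attenuated by the
# analysis window.
def _example_inverse_stft() -> th.Tensor:
    wav = th.randn(1, 16000)
    spec = forward_stft(wav, 400, 160, window="hann")
    return inverse_stft(spec, 400, 160, window="hann")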


class STFTBase(nn.Module):
    """
    Base layer for (i)STFT
    Args:
        frame_len: length of the frame
        frame_hop: hop size between frames
        window: window name
        center: center flag (similar to that in librosa.stft)
        round_pow_of_two: if True, round the FFT size up to the nearest power of two
        normalized: use normalized DFT kernel
        pre_emphasis: factor of preemphasis
        mode: STFT mode, "kaldi", "librosa" or "torch"
        onesided: output onesided STFT
        inverse: use the iDFT kernel (for iSTFT)
    """

    def __init__(
        self,
        frame_len: int,
        frame_hop: int,
        window: str = "sqrthann",
        round_pow_of_two: bool = True,
        normalized: bool = False,
        pre_emphasis: float = 0,
        onesided: bool = True,
        inverse: bool = False,
        center: bool = False,
        mode: str = "librosa",
    ) -> None:
        super(STFTBase, self).__init__()
        if mode != "torch":
            K, w = init_kernel(
                frame_len,
                frame_hop,
                init_window(window, frame_len),
                round_pow_of_two=round_pow_of_two,
                normalized=normalized,
                inverse=inverse,
                mode=mode,
            )
            self.K = nn.Parameter(K, requires_grad=False)
            self.w = nn.Parameter(w, requires_grad=False)
            self.num_bins = self.K.shape[0] // 4 + 1
            self.pre_emphasis = pre_emphasis
            self.win_length = self.K.shape[2]
        else:
            self.K = None
            w = init_window(window, frame_len)
            self.w = nn.Parameter(w, requires_grad=False)
            fft_size = (
                2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
            )
            self.num_bins = fft_size // 2 + 1
            self.pre_emphasis = 0
            self.win_length = fft_size
        self.frame_len = frame_len
        self.frame_hop = frame_hop
        self.window = window
        self.normalized = normalized
        self.onesided = onesided
        self.center = center
        self.mode = mode

    def num_frames(self, wav_len: th.Tensor) -> th.Tensor:
        """
        Compute the number of frames
        """
        assert th.sum(wav_len <= self.win_length) == 0, (
            f"wav_len should be larger than win_length ({self.win_length})"
        )
        if self.center:
            # avoid modifying the caller's tensor in place
            wav_len = wav_len + self.win_length
        return (
            th.div(wav_len - self.win_length, self.frame_hop, rounding_mode="trunc") + 1
        )

    def extra_repr(self) -> str:
        str_repr = (
            f"num_bins={self.num_bins}, win_length={self.win_length}, "
            + f"stride={self.frame_hop}, window={self.window}, "
            + f"center={self.center}, mode={self.mode}"
        )
        if not self.onesided:
            str_repr += f", onesided={self.onesided}"
        if self.pre_emphasis > 0:
            str_repr += f", pre_emphasis={self.pre_emphasis}"
        if self.normalized:
            str_repr += f", normalized={self.normalized}"
        return str_repr


class STFT(STFTBase):
    """
    Short-time Fourier Transform as a Layer
    """

    def __init__(self, *args, **kwargs):
        super(STFT, self).__init__(*args, inverse=False, **kwargs)

    def forward(
        self, wav: th.Tensor, return_polar: bool = False, eps: float = EPSILON
    ) -> th.Tensor:
        """
        Accept (single or multi-channel) raw waveform and output the STFT transform
        (real/imaginary or magnitude/phase, depending on return_polar)
        Args
            wav (Tensor): input signal, N x (C) x S
        Return
            transform (Tensor): N x (C) x F x T x 2
        """
        if self.mode == "torch":
            return _pytorch_stft(
                wav,
                self.frame_len,
                self.frame_hop,
                n_fft=(self.num_bins - 1) * 2,
                return_polar=return_polar,
                window=self.w,
                normalized=self.normalized,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
        else:
            return _forward_stft(
                wav,
                self.K,
                self.w,
                return_polar=return_polar,
                frame_hop=self.frame_hop,
                pre_emphasis=self.pre_emphasis,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
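

# Layer usage sketch (assumed parameters): for the conv1d-based modes the STFT layer
# stores the DFT kernel and window as non-trainable parameters, so it follows
# .to(device) like any nn.Module; num_frames() predicts the number of output frames
# from the waveform lengths.
def _example_stft_layer() -> Tuple[th.Tensor, th.Tensor]:
    stft = STFT(400, 160, window="hann", center=True)
    wav = th.randn(2, 16000)
    spec = stft(wav)                                  # 2 x 257 x T x 2
    num_frames = stft.num_frames(th.tensor([16000]))  # equals spec.shape[-2]
    return spec, num_frames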


class iSTFT(STFTBase):
    """
    Inverse Short-time Fourier Transform as a Layer
    """

    def __init__(self, *args, **kwargs):
        super(iSTFT, self).__init__(*args, inverse=True, **kwargs)

    def forward(
        self, transform: th.Tensor, return_polar: bool = False, eps: float = EPSILON
    ) -> th.Tensor:
        """
        Accept the STFT transform (real/imaginary or magnitude/phase, depending on
        return_polar) and output the raw waveform
        Args
            transform (Tensor): STFT output, N x F x T x 2
        Return
            wav (Tensor): N x S
        """
        if self.mode == "torch":
            return _pytorch_istft(
                transform,
                self.frame_len,
                self.frame_hop,
                n_fft=(self.num_bins - 1) * 2,
                return_polar=return_polar,
                window=self.w,
                normalized=self.normalized,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
        else:
            return _inverse_stft(
                transform,
                self.K,
                self.w,
                return_polar=return_polar,
                frame_hop=self.frame_hop,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
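

# Analysis/synthesis sketch: an STFT/iSTFT layer pair with the same configuration
# forms a (near-)reconstruction pipeline, e.g. around a spectral masking model.
def _example_stft_istft_layers() -> th.Tensor:
    stft = STFT(400, 160, window="hann", center=True)
    istft = iSTFT(400, 160, window="hann", center=True)
    wav = th.randn(1, 16000)
    spec = stft(wav, return_polar=False)
    return istft(spec, return_polar=False)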