|
import math |
|
import argparse |
|
|
|
import librosa |
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn.parameter import Parameter |
|
|
|
|
|
class DFTBase(nn.Module):
    r"""Shared base for modules that need DFT / IDFT matrices."""

    def __init__(self):
        super().__init__()

    def _fourier_matrix(self, n, sign):
        r"""Build the (n, n) matrix whose entry (i, j) is omega ** (i * j),
        with omega = exp(sign * 2 * pi * 1j / n).
        """
        indices = np.arange(n)
        # Same integer grid as multiplying the two meshgrid outputs.
        exponents = np.outer(indices, indices)
        omega = np.exp(sign * 2 * np.pi * 1j / n)
        return np.power(omega, exponents)

    def dft_matrix(self, n):
        r"""Return the (n, n) complex forward DFT matrix.

        Args:
          n: int, transform size

        Returns:
          W: (n, n) complex ndarray
        """
        return self._fourier_matrix(n, -1)

    def idft_matrix(self, n):
        r"""Return the (n, n) complex inverse DFT matrix (not divided by n).

        Args:
          n: int, transform size

        Returns:
          W: (n, n) complex ndarray
        """
        return self._fourier_matrix(n, 1)
|
|
|
|
|
class DFT(DFTBase):
    def __init__(self, n, norm):
        r"""Calculate discrete Fourier transform (DFT), inverse DFT (IDFT),
        real-input DFT (RDFT), and inverse RDFT (IRDFT) by matrix
        multiplication.

        Args:
          n: int, fft window size
          norm: None | 'ortho'. Any other value applies no normalization.
        """
        super(DFT, self).__init__()

        self.W = self.dft_matrix(n)
        self.inv_W = self.idft_matrix(n)

        # Stored as plain tensor attributes (neither parameters nor
        # registered buffers).
        self.W_real = torch.Tensor(np.real(self.W))
        self.W_imag = torch.Tensor(np.imag(self.W))
        self.inv_W_real = torch.Tensor(np.real(self.inv_W))
        self.inv_W_imag = torch.Tensor(np.imag(self.inv_W))

        self.n = n
        self.norm = norm

    def dft(self, x_real, x_imag):
        r"""Calculate DFT of a signal.

        Args:
          x_real: (n,), real part of a signal
          x_imag: (n,), imag part of a signal

        Returns:
          z_real: (n,), real part of output
          z_imag: (n,), imag part of output
        """
        # Complex product (x_real + j x_imag) @ (W_real + j W_imag).
        z_real = torch.matmul(x_real, self.W_real) - torch.matmul(x_imag, self.W_imag)
        z_imag = torch.matmul(x_imag, self.W_real) + torch.matmul(x_real, self.W_imag)

        if self.norm is None:
            pass
        elif self.norm == 'ortho':
            z_real /= math.sqrt(self.n)
            z_imag /= math.sqrt(self.n)

        return z_real, z_imag

    def idft(self, x_real, x_imag):
        r"""Calculate IDFT of a signal.

        Args:
          x_real: (n,), real part of a signal
          x_imag: (n,), imag part of a signal

        Returns:
          z_real: (n,), real part of output
          z_imag: (n,), imag part of output
        """
        z_real = torch.matmul(x_real, self.inv_W_real) - torch.matmul(x_imag, self.inv_W_imag)
        z_imag = torch.matmul(x_imag, self.inv_W_real) + torch.matmul(x_real, self.inv_W_imag)

        if self.norm is None:
            # Both parts must be scaled; scaling only the real part leaves
            # the imaginary part n times too large.
            z_real /= self.n
            z_imag /= self.n
        elif self.norm == 'ortho':
            z_real /= math.sqrt(self.n)
            z_imag /= math.sqrt(self.n)

        return z_real, z_imag

    def rdft(self, x_real):
        r"""Calculate RDFT (one-sided DFT of a real signal).

        Args:
          x_real: (n,), real signal

        Returns:
          z_real: (n // 2 + 1,), real part of output
          z_imag: (n // 2 + 1,), imag part of output
        """
        n_rfft = self.n // 2 + 1
        # Only the first n // 2 + 1 columns of the DFT matrix are needed.
        z_real = torch.matmul(x_real, self.W_real[..., 0 : n_rfft])
        z_imag = torch.matmul(x_real, self.W_imag[..., 0 : n_rfft])

        if self.norm is None:
            pass
        elif self.norm == 'ortho':
            z_real /= math.sqrt(self.n)
            z_imag /= math.sqrt(self.n)

        return z_real, z_imag

    def irdft(self, x_real, x_imag):
        r"""Calculate IRDFT (inverse of the one-sided DFT).

        Args:
          x_real: (n // 2 + 1,), real part of a one-sided spectrum
          x_imag: (n // 2 + 1,), imag part of a one-sided spectrum

        Returns:
          z_real: (n,), reconstructed real signal
        """
        n_rfft = self.n // 2 + 1

        # Rebuild the full spectrum from conjugate symmetry. For even n the
        # last one-sided bin is the Nyquist bin and must not be mirrored;
        # for odd n there is no Nyquist bin, so the last bin is mirrored too.
        start = 1 if self.n % 2 == 0 else 0
        mirror_real = torch.flip(x_real, dims=(-1,))[..., start : n_rfft - 1]
        mirror_imag = torch.flip(x_imag, dims=(-1,))[..., start : n_rfft - 1]

        x_real = torch.cat((x_real, mirror_real), dim=-1)
        x_imag = torch.cat((x_imag, -1. * mirror_imag), dim=-1)

        z_real = torch.matmul(x_real, self.inv_W_real) - torch.matmul(x_imag, self.inv_W_imag)

        if self.norm is None:
            z_real /= self.n
        elif self.norm == 'ortho':
            z_real /= math.sqrt(self.n)

        return z_real
|
|
|
|
|
class STFT(DFTBase):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None,
        window='hann', center=True, pad_mode='reflect', freeze_parameters=True):
        r"""STFT implemented with two Conv1d layers (real and imaginary
        parts) whose kernels are windowed DFT basis vectors. Intended to give
        the same output as librosa.stft.

        Args:
          n_fft: int, fft window size, e.g., 2048
          hop_length: int, hop length samples, e.g., 441. Defaults to
            win_length // 4.
          win_length: int, window length e.g., 2048. Defaults to n_fft.
          window: str, window function name, e.g., 'hann'
          center: bool, pad n_fft // 2 samples on both sides in forward()
          pad_mode: str, 'constant' | 'reflect'
          freeze_parameters: bool, set to True to freeze all parameters. Set
            to False to finetune all parameters.
        """
        super().__init__()

        assert pad_mode in ['constant', 'reflect']

        self.n_fft = n_fft
        self.window = window
        self.center = center
        self.pad_mode = pad_mode
        self.win_length = n_fft if win_length is None else win_length
        if hop_length is None:
            self.hop_length = int(self.win_length // 4)
        else:
            self.hop_length = hop_length

        # Analysis window, zero-padded (centered) up to n_fft samples.
        analysis_window = librosa.filters.get_window(
            window, self.win_length, fftbins=True)
        analysis_window = librosa.util.pad_center(analysis_window, size=n_fft)

        # DFT matrix
        self.W = self.dft_matrix(n_fft)

        n_bins = n_fft // 2 + 1

        conv_kwargs = dict(in_channels=1, out_channels=n_bins,
            kernel_size=n_fft, stride=self.hop_length, padding=0, dilation=1,
            groups=1, bias=False)
        self.conv_real = nn.Conv1d(**conv_kwargs)
        self.conv_imag = nn.Conv1d(**conv_kwargs)

        # One-sided DFT basis multiplied by the window: (n_fft, n_bins).
        windowed_basis = self.W[:, 0 : n_bins] * analysis_window[:, None]

        # Conv1d weights have shape (n_bins, 1, n_fft).
        real_kernel = torch.Tensor(np.real(windowed_basis).T)[:, None, :]
        imag_kernel = torch.Tensor(np.imag(windowed_basis).T)[:, None, :]
        self.conv_real.weight.data.copy_(real_kernel)
        self.conv_imag.weight.data.copy_(imag_kernel)

        if freeze_parameters:
            for weight in self.parameters():
                weight.requires_grad_(False)

    def forward(self, input):
        r"""Calculate STFT of batch of signals.

        Args:
          input: (batch_size, data_length), input signals.

        Returns:
          real: (batch_size, 1, time_steps, n_fft // 2 + 1)
          imag: (batch_size, 1, time_steps, n_fft // 2 + 1)
        """
        # (batch_size, channels_num=1, data_length)
        signal = input.unsqueeze(1)

        if self.center:
            half = self.n_fft // 2
            signal = F.pad(signal, pad=(half, half), mode=self.pad_mode)

        # Each conv output: (batch_size, n_fft // 2 + 1, time_steps)
        real = self.conv_real(signal)
        imag = self.conv_imag(signal)

        # Add a channel axis and put time before frequency.
        real = real.unsqueeze(1).transpose(2, 3)
        imag = imag.unsqueeze(1).transpose(2, 3)

        return real, imag
|
|
|
|
|
def magphase(real, imag):
    r"""Split a complex signal given as real and imag parts into magnitude
    and unit phase components.

    Args:
      real: tensor, real part of signals
      imag: tensor, imag part of signals

    Returns:
      mag: tensor, magnitude of signals
      cos: tensor, cosine of phases of signals
      sin: tensor, sine of phases of signals
    """
    power = torch.pow(real, 2) + torch.pow(imag, 2)
    mag = power ** 0.5

    # Floor the divisor so zero-magnitude bins give 0 instead of NaN.
    divisor = torch.clamp(mag, min=1e-10, max=np.inf)

    return mag, real / divisor, imag / divisor
|
|
|
|
|
class ISTFT(DFTBase):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None,
        window='hann', center=True, pad_mode='reflect', freeze_parameters=True,
        onnx=False, frames_num=None, device=None):
        r"""PyTorch implementation of ISTFT with Conv1d. Intended to give the
        same output as librosa.istft.

        Args:
          n_fft: int, fft window size, e.g., 2048
          hop_length: int, hop length samples, e.g., 441. Defaults to
            win_length // 4.
          win_length: int, window length e.g., 2048. Defaults to n_fft.
          window: str, window function name, e.g., 'hann'
          center: bool
          pad_mode: str, e.g., 'reflect'
          freeze_parameters: bool, set to True to freeze all parameters. Set
            to False to finetune all parameters.
          onnx: bool, set to True when exporting trained model to ONNX. This
            will replace several operations to operators supported by ONNX.
          frames_num: None | int, number of frames of audio clips to be
            inferenced. Only usable when onnx=True.
          device: None | str, device of ONNX. Only usable when onnx=True.
        """
        super(ISTFT, self).__init__()

        assert pad_mode in ['constant', 'reflect']

        if not onnx:
            assert frames_num is None, "When onnx=False, frames_num must be None!"
            assert device is None, "When onnx=False, device must be None!"

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.center = center
        self.pad_mode = pad_mode
        self.onnx = onnx

        # By default, use the entire frame.
        if self.win_length is None:
            self.win_length = self.n_fft

        # Set the default hop, if it's not already specified.
        if self.hop_length is None:
            self.hop_length = int(self.win_length // 4)

        # Initialize Conv1d modules for calculating real and imag part of DFT.
        self.init_real_imag_conv()

        # Initialize overlap add window for reconstructing time domain signals.
        self.init_overlap_add_window()

        if self.onnx:
            # Initialize ONNX modules.
            self.init_onnx_modules(frames_num, device)

        # Runs last so the ONNX helper layers are frozen as well.
        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def init_real_imag_conv(self):
        r"""Initialize Conv1d for calculating real and imag part of DFT.
        """
        # Inverse DFT matrix with the 1 / n_fft scaling folded in.
        self.W = self.idft_matrix(self.n_fft) / self.n_fft

        self.conv_real = nn.Conv1d(in_channels=self.n_fft, out_channels=self.n_fft,
            kernel_size=1, stride=1, padding=0, dilation=1,
            groups=1, bias=False)

        self.conv_imag = nn.Conv1d(in_channels=self.n_fft, out_channels=self.n_fft,
            kernel_size=1, stride=1, padding=0, dilation=1,
            groups=1, bias=False)

        ifft_window = librosa.filters.get_window(self.window, self.win_length, fftbins=True)
        # (win_length,)

        # Pad the window to n_fft
        ifft_window = librosa.util.pad_center(ifft_window, size=self.n_fft)

        self.conv_real.weight.data = torch.Tensor(
            np.real(self.W * ifft_window[None, :]).T)[:, :, None]
        # (n_fft, n_fft, 1)

        self.conv_imag.weight.data = torch.Tensor(
            np.imag(self.W * ifft_window[None, :]).T)[:, :, None]
        # (n_fft, n_fft, 1)

    def init_overlap_add_window(self):
        r"""Initialize overlap add window for reconstructing time domain
        signals.
        """
        ola_window = librosa.filters.get_window(self.window, self.win_length, fftbins=True)
        # (win_length,)

        ola_window = librosa.util.normalize(ola_window, norm=None) ** 2
        ola_window = librosa.util.pad_center(ola_window, size=self.n_fft)
        ola_window = torch.Tensor(ola_window)

        self.register_buffer('ola_window', ola_window)
        # (n_fft,)

    def init_onnx_modules(self, frames_num, device):
        r"""Initialize ONNX modules.

        Args:
          frames_num: int
          device: str | None
        """
        # Use Conv1d to implement torch.flip(), because torch.flip() is not
        # supported by ONNX.
        self.reverse = nn.Conv1d(in_channels=self.n_fft // 2 + 1,
            out_channels=self.n_fft // 2 - 1, kernel_size=1, bias=False)

        # Anti-diagonal selection of bins 1 .. n_fft // 2 - 1.
        tmp = np.zeros((self.n_fft // 2 - 1, self.n_fft // 2 + 1, 1))
        tmp[:, 1 : -1, 0] = np.array(np.eye(self.n_fft // 2 - 1)[::-1])
        self.reverse.weight.data = torch.Tensor(tmp)
        # (n_fft // 2 - 1, n_fft // 2 + 1, 1)

        # Use nn.ConvTranspose2d to implement torch.nn.functional.fold(),
        # because fold() is not supported by ONNX.
        self.overlap_add = nn.ConvTranspose2d(in_channels=self.n_fft,
            out_channels=1, kernel_size=(self.n_fft, 1), stride=(self.hop_length, 1), bias=False)

        self.overlap_add.weight.data = torch.Tensor(np.eye(self.n_fft)[:, None, :, None])
        # (n_fft, 1, n_fft, 1)

        if frames_num:
            # Pre-calculate overlap-add window sum for reconstructing signals
            # when using ONNX.
            self.ifft_window_sum = self._get_ifft_window_sum_onnx(frames_num, device)
        else:
            self.ifft_window_sum = []

    def forward(self, real_stft, imag_stft, length):
        r"""Calculate inverse STFT.

        Args:
          real_stft: (batch_size, channels=1, time_steps, n_fft // 2 + 1)
          imag_stft: (batch_size, channels=1, time_steps, n_fft // 2 + 1)
          length: int | None, number of output samples to keep

        Returns:
          real: (batch_size, data_length), output signals.
        """
        assert real_stft.ndimension() == 4 and imag_stft.ndimension() == 4
        frames_num = real_stft.shape[2]

        real_stft = real_stft[:, 0, :, :].transpose(1, 2)
        imag_stft = imag_stft[:, 0, :, :].transpose(1, 2)
        # (batch_size, n_fft // 2 + 1, time_steps)

        # Get full stft representation from spectrum using symmetry attribute.
        if self.onnx:
            full_real_stft, full_imag_stft = self._get_full_stft_onnx(real_stft, imag_stft)
        else:
            full_real_stft, full_imag_stft = self._get_full_stft(real_stft, imag_stft)
        # full_real_stft: (batch_size, n_fft, time_steps)
        # full_imag_stft: (batch_size, n_fft, time_steps)

        # Real part of the windowed inverse DFT of every frame.
        s_real = self.conv_real(full_real_stft) - self.conv_imag(full_imag_stft)
        # (batch_size, n_fft, time_steps)

        # Overlap add signals in frames to reconstruct signals.
        if self.onnx:
            y = self._overlap_add_divide_window_sum_onnx(s_real, frames_num)
        else:
            y = self._overlap_add_divide_window_sum(s_real, frames_num)
        # y: (batch_size, audio_samples + win_length,)

        y = self._trim_edges(y, length)
        # (batch_size, audio_samples,)

        return y

    def _get_full_stft(self, real_stft, imag_stft):
        r"""Get full stft representation from spectrum using symmetry attribute.

        Args:
          real_stft: (batch_size, n_fft // 2 + 1, time_steps)
          imag_stft: (batch_size, n_fft // 2 + 1, time_steps)

        Returns:
          full_real_stft: (batch_size, n_fft, time_steps)
          full_imag_stft: (batch_size, n_fft, time_steps)
        """
        # Mirror the interior bins; the imaginary part is negated (conjugate).
        full_real_stft = torch.cat((real_stft, torch.flip(real_stft[:, 1 : -1, :], dims=[1])), dim=1)
        full_imag_stft = torch.cat((imag_stft, - torch.flip(imag_stft[:, 1 : -1, :], dims=[1])), dim=1)

        return full_real_stft, full_imag_stft

    def _get_full_stft_onnx(self, real_stft, imag_stft):
        r"""Get full stft representation from spectrum using symmetry attribute
        for ONNX. Replace several pytorch operations in self._get_full_stft()
        that are not supported by ONNX.

        Args:
          real_stft: (batch_size, n_fft // 2 + 1, time_steps)
          imag_stft: (batch_size, n_fft // 2 + 1, time_steps)

        Returns:
          full_real_stft: (batch_size, n_fft, time_steps)
          full_imag_stft: (batch_size, n_fft, time_steps)
        """
        # Implement torch.flip() with Conv1d.
        full_real_stft = torch.cat((real_stft, self.reverse(real_stft)), dim=1)
        full_imag_stft = torch.cat((imag_stft, - self.reverse(imag_stft)), dim=1)

        return full_real_stft, full_imag_stft

    def _overlap_add_divide_window_sum(self, s_real, frames_num):
        r"""Overlap add signals in frames to reconstruct signals.

        Args:
          s_real: (batch_size, n_fft, time_steps), signals in frames
          frames_num: int

        Returns:
          y: (batch_size, audio_samples)
        """
        # Every frame is n_fft samples long (the window was padded to n_fft),
        # so n_fft, not win_length, is the fold kernel size. The ONNX path
        # uses the same size.
        output_samples = (s_real.shape[-1] - 1) * self.hop_length + self.n_fft
        # (audio_samples,)

        y = F.fold(input=s_real, output_size=(1, output_samples),
            kernel_size=(1, self.n_fft), stride=(1, self.hop_length))
        # (batch_size, 1, 1, audio_samples,)

        y = y[:, 0, 0, :]
        # (batch_size, audio_samples)

        # Get overlap-add window sum to be divided.
        ifft_window_sum = self._get_ifft_window(frames_num)
        # (audio_samples,)

        # Floor the divisor to avoid division by ~0 at the signal edges.
        ifft_window_sum = torch.clamp(ifft_window_sum, 1e-11, np.inf)
        # (audio_samples,)

        y = y / ifft_window_sum[None, :]
        # (batch_size, audio_samples,)

        return y

    def _get_ifft_window(self, frames_num):
        r"""Get overlap-add window sum to be divided.

        Args:
          frames_num: int

        Returns:
          ifft_window_sum: (audio_samples,), overlap-add window sum to be
          divided.
        """
        # self.ola_window has n_fft entries, so fold with an n_fft kernel.
        output_samples = (frames_num - 1) * self.hop_length + self.n_fft
        # (audio_samples,)

        window_matrix = self.ola_window[None, :, None].repeat(1, 1, frames_num)
        # (batch_size, n_fft, time_steps)

        ifft_window_sum = F.fold(input=window_matrix,
            output_size=(1, output_samples), kernel_size=(1, self.n_fft),
            stride=(1, self.hop_length))
        # (1, 1, 1, audio_samples)

        ifft_window_sum = ifft_window_sum.squeeze()
        # (audio_samples,)

        return ifft_window_sum

    def _overlap_add_divide_window_sum_onnx(self, s_real, frames_num):
        r"""Overlap add signals in frames to reconstruct signals for ONNX.
        Replace several pytorch operations in
        self._overlap_add_divide_window_sum() that are not supported by ONNX.

        Args:
          s_real: (batch_size, n_fft, time_steps), signals in frames
          frames_num: int

        Returns:
          y: (batch_size, audio_samples)
        """
        s_real = s_real[..., None]
        # (batch_size, n_fft, time_steps, 1)

        # Implement overlap-add with Conv1d, because torch.nn.functional.fold()
        # is not supported by ONNX.
        y = self.overlap_add(s_real)[:, 0, :, 0]
        # y: (batch_size, samples_num)

        # Recompute the cached window sum when its length does not match.
        if len(self.ifft_window_sum) != y.shape[1]:
            device = s_real.device

            self.ifft_window_sum = self._get_ifft_window_sum_onnx(frames_num, device)
            # (audio_samples,)

        # Floor the divisor to avoid division by ~0 at the signal edges.
        ifft_window_sum = torch.clamp(self.ifft_window_sum, 1e-11, np.inf)
        # (audio_samples,)

        y = y / ifft_window_sum[None, :]
        # (batch_size, audio_samples,)

        return y

    def _get_ifft_window_sum_onnx(self, frames_num, device):
        r"""Pre-calculate overlap-add window sum for reconstructing signals when
        using ONNX.

        Args:
          frames_num: int
          device: str | None

        Returns:
          ifft_window_sum: (audio_samples,)
        """
        ifft_window_sum = librosa.filters.window_sumsquare(window=self.window,
            n_frames=frames_num, win_length=self.win_length, n_fft=self.n_fft,
            hop_length=self.hop_length)
        # (audio_samples,)

        ifft_window_sum = torch.Tensor(ifft_window_sum)

        if device:
            ifft_window_sum = ifft_window_sum.to(device)

        return ifft_window_sum

    def _trim_edges(self, y, length):
        r"""Trim audio.

        Args:
          y: (batch_size, audio_samples)
          length: int | None. When None and center=True, n_fft // 2 samples
            are removed from both ends; when None and center=False, y is
            returned unchanged.

        Returns:
          (batch_size, trimmed_audio_samples)
        """
        pad = self.n_fft // 2

        if length is None:
            if self.center:
                # Explicit end index: `-self.n_fft // 2` floors towards
                # -inf and removes one sample too many for odd n_fft.
                y = y[:, pad : y.shape[-1] - pad]
        else:
            start = pad if self.center else 0
            y = y[:, start : start + length]

        return y
|
|
|
|
|
class Spectrogram(nn.Module):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None,
        window='hann', center=True, pad_mode='reflect', power=2.0,
        freeze_parameters=True):
        r"""Calculate spectrogram using pytorch. The STFT is implemented with
        Conv1d. Intended to give the same output as librosa.stft followed by
        taking the magnitude to the given power.

        Args:
          n_fft: int, fft window size
          hop_length: int | None, forwarded to STFT
          win_length: int | None, forwarded to STFT
          window: str, window function name
          center: bool, forwarded to STFT
          pad_mode: str, forwarded to STFT
          power: float, exponent applied to the magnitude (2.0 gives the
            power spectrogram)
          freeze_parameters: bool, forwarded to STFT. Set to False to make
            the STFT kernels trainable.
        """
        super(Spectrogram, self).__init__()

        self.power = power

        # Honour the caller's freeze_parameters instead of always freezing.
        self.stft = STFT(n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center,
            pad_mode=pad_mode, freeze_parameters=freeze_parameters)

    def forward(self, input):
        r"""Calculate spectrogram of input signals.

        Args:
          input: (batch_size, data_length)

        Returns:
          spectrogram: (batch_size, 1, time_steps, n_fft // 2 + 1)
        """
        (real, imag) = self.stft(input)
        # (batch_size, 1, time_steps, n_fft // 2 + 1)

        # Squared magnitude, i.e. the power spectrogram.
        spectrogram = real ** 2 + imag ** 2

        if self.power != 2.0:
            # |X| ** power == (|X| ** 2) ** (power / 2)
            spectrogram = spectrogram ** (self.power / 2.0)

        return spectrogram
|
|
|
|
|
class LogmelFilterBank(nn.Module):
    def __init__(self, sr=22050, n_fft=2048, n_mels=64, fmin=0.0, fmax=None,
        is_log=True, ref=1.0, amin=1e-10, top_db=80.0, freeze_parameters=True):
        r"""Calculate logmel spectrogram using pytorch. The mel filter bank is
        built with librosa.filters.mel and stored as a parameter.

        Args:
          sr: int, sample rate
          n_fft: int, fft window size
          n_mels: int, number of mel bins
          fmin: float, lowest filter bank frequency
          fmax: float | None, highest filter bank frequency. Defaults to
            sr // 2.
          is_log: bool, return the dB-scaled mel spectrogram when True
          ref: float, reference power for the dB conversion
          amin: float, lower bound applied before taking the logarithm
          top_db: float | None, dynamic range kept below the maximum
          freeze_parameters: bool, set to True to freeze the filter bank
        """
        super(LogmelFilterBank, self).__init__()

        self.is_log = is_log
        self.ref = ref
        self.amin = amin
        self.top_db = top_db
        if fmax is None:
            fmax = sr//2

        self.melW = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels,
            fmin=fmin, fmax=fmax).T
        # (n_fft // 2 + 1, mel_bins)

        self.melW = nn.Parameter(torch.Tensor(self.melW).contiguous())

        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        r"""Calculate (log) mel spectrogram from spectrogram.

        Args:
          input: (*, n_fft // 2 + 1), spectrogram

        Returns:
          output: (*, mel_bins), (log) mel spectrogram
        """
        # Mel spectrogram
        mel_spectrogram = torch.matmul(input, self.melW)
        # (*, mel_bins)

        # Logmel spectrogram
        if self.is_log:
            output = self.power_to_db(mel_spectrogram)
        else:
            output = mel_spectrogram

        return output

    def power_to_db(self, input):
        r"""Power to db, this function is the pytorch implementation of
        librosa.power_to_db

        Args:
          input: tensor, power values

        Returns:
          log_spec: tensor of the same shape, values in dB relative to ref

        Raises:
          librosa.util.exceptions.ParameterError: if top_db is negative
        """
        ref_value = self.ref
        # Floor the input at amin so log10 never sees zero.
        log_spec = 10.0 * torch.log10(torch.clamp(input, min=self.amin, max=np.inf))
        log_spec -= 10.0 * np.log10(np.maximum(self.amin, ref_value))

        if self.top_db is not None:
            if self.top_db < 0:
                raise librosa.util.exceptions.ParameterError('top_db must be non-negative')
            # The maximum is taken over the whole tensor (batch included).
            log_spec = torch.clamp(log_spec, min=log_spec.max().item() - self.top_db, max=np.inf)

        return log_spec
|
|
|
|
|
class Enframe(nn.Module):
    def __init__(self, frame_length=2048, hop_length=512):
        r"""Enframe a time sequence. This function is the pytorch implementation
        of librosa.util.frame

        Args:
          frame_length: int, number of samples per frame
          hop_length: int, number of samples between frame starts
        """
        super(Enframe, self).__init__()

        self.enframe_conv = nn.Conv1d(in_channels=1, out_channels=frame_length,
            kernel_size=frame_length, stride=hop_length,
            padding=0, bias=False)

        # Identity kernel: output channel k copies sample k of each frame.
        self.enframe_conv.weight.data = torch.eye(frame_length)[:, None, :]
        self.enframe_conv.weight.requires_grad = False

    def forward(self, input):
        r"""Enframe signals into frames.

        Args:
          input: (batch_size, samples)

        Returns:
          output: (batch_size, window_length, frames_num)
        """
        output = self.enframe_conv(input[:, None, :])
        return output

    def power_to_db(self, input):
        r"""Power to db, this function is the pytorch implementation of
        librosa.power_to_db.

        Enframe does not set ref, amin or top_db in its constructor, so the
        librosa.power_to_db defaults (ref=1.0, amin=1e-10, top_db=80.0) are
        used unless those attributes have been assigned on the instance.

        Args:
          input: tensor, power values

        Returns:
          log_spec: tensor of the same shape, values in dB relative to ref

        Raises:
          librosa.util.exceptions.ParameterError: if top_db is negative
        """
        ref_value = getattr(self, 'ref', 1.0)
        amin = getattr(self, 'amin', 1e-10)
        top_db = getattr(self, 'top_db', 80.0)

        # Floor the input at amin so log10 never sees zero.
        log_spec = 10.0 * torch.log10(torch.clamp(input, min=amin, max=np.inf))
        log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))

        if top_db is not None:
            if top_db < 0:
                raise librosa.util.exceptions.ParameterError('top_db must be non-negative')
            # .item() keeps both clamp bounds plain Python numbers.
            log_spec = torch.clamp(log_spec, min=log_spec.max().item() - top_db, max=np.inf)

        return log_spec
|
|
|
|
|
class Scalar(nn.Module):
    r"""Standardize inputs with a stored mean and standard deviation."""

    def __init__(self, scalar, freeze_parameters):
        r"""
        Args:
          scalar: mapping with keys 'mean' and 'std'
          freeze_parameters: bool, set to True to freeze both statistics
        """
        super().__init__()

        mean_values = torch.Tensor(scalar['mean'])
        std_values = torch.Tensor(scalar['std'])
        self.scalar_mean = Parameter(mean_values)
        self.scalar_std = Parameter(std_values)

        if freeze_parameters:
            for statistic in (self.scalar_mean, self.scalar_std):
                statistic.requires_grad_(False)

    def forward(self, input):
        r"""Return (input - mean) / std."""
        centered = input - self.scalar_mean
        return centered / self.scalar_std
|
|
|
|
|
def debug(select, device):
    r"""Compare numpy + librosa and torchlibrosa results. For debug.

    Each mode prints mean absolute differences (or a pass flag) between the
    reference implementation and the modules defined in this file.

    Args:
      select: 'dft' | 'stft' | 'logmel' | 'enframe' | 'default'
      device: 'cpu' | 'cuda'

    Raises:
      ValueError: if select is not one of the supported modes
    """

    if select == 'dft':
        n = 10
        norm = None     # None | 'ortho'
        np.random.seed(0)

        # Data
        np_data = np.random.uniform(-1, 1, n)
        pt_data = torch.Tensor(np_data)

        # Numpy FFT
        np_fft = np.fft.fft(np_data, norm=norm)
        np_ifft = np.fft.ifft(np_fft, norm=norm)
        np_rfft = np.fft.rfft(np_data, norm=norm)
        # The inverse of rfft is irfft; ifft would treat the one-sided
        # spectrum as a full spectrum.
        np_irfft = np.fft.irfft(np_rfft, n=n, norm=norm)

        # Pytorch FFT
        obj = DFT(n, norm)
        pt_dft = obj.dft(pt_data, torch.zeros_like(pt_data))
        pt_idft = obj.idft(pt_dft[0], pt_dft[1])
        pt_rdft = obj.rdft(pt_data)
        pt_irdft = obj.irdft(pt_rdft[0], pt_rdft[1])

        print('Comparing librosa and pytorch implementation of DFT. All numbers '
            'below should be close to 0.')
        print(np.mean((np.abs(np.real(np_fft) - pt_dft[0].cpu().numpy()))))
        print(np.mean((np.abs(np.imag(np_fft) - pt_dft[1].cpu().numpy()))))

        print(np.mean((np.abs(np.real(np_ifft) - pt_idft[0].cpu().numpy()))))
        print(np.mean((np.abs(np.imag(np_ifft) - pt_idft[1].cpu().numpy()))))

        print(np.mean((np.abs(np.real(np_rfft) - pt_rdft[0].cpu().numpy()))))
        print(np.mean((np.abs(np.imag(np_rfft) - pt_rdft[1].cpu().numpy()))))

        print(np.mean(np.abs(np_irfft - pt_irdft.cpu().numpy())))

    elif select == 'stft':
        device = torch.device(device)
        np.random.seed(0)

        # Spectrogram parameters (the same as librosa.stft)
        sample_rate = 22050
        data_length = sample_rate * 1
        n_fft = 2048
        hop_length = 512
        win_length = 2048
        window = 'hann'
        center = True
        pad_mode = 'reflect'

        # Data
        np_data = np.random.uniform(-1, 1, data_length)
        pt_data = torch.Tensor(np_data).to(device)

        # Numpy stft matrix. win_length and pad_mode are passed explicitly
        # so the comparison does not depend on librosa's defaults.
        np_stft_matrix = librosa.stft(y=np_data, n_fft=n_fft,
            hop_length=hop_length, win_length=win_length, window=window,
            center=center, pad_mode=pad_mode).T

        # Pytorch stft matrix
        pt_stft_extractor = STFT(n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)

        pt_stft_extractor.to(device)

        (pt_stft_real, pt_stft_imag) = pt_stft_extractor.forward(pt_data[None, :])

        print('Comparing librosa and pytorch implementation of STFT & ISTFT. '
            'All numbers below should be close to 0.')
        print(np.mean(np.abs(np.real(np_stft_matrix) - pt_stft_real.data.cpu().numpy()[0, 0])))
        print(np.mean(np.abs(np.imag(np_stft_matrix) - pt_stft_imag.data.cpu().numpy()[0, 0])))

        # Numpy istft
        np_istft_s = librosa.istft(stft_matrix=np_stft_matrix.T,
            hop_length=hop_length, win_length=win_length, window=window,
            center=center, length=data_length)

        # Pytorch istft
        pt_istft_extractor = ISTFT(n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)
        pt_istft_extractor.to(device)

        # Recover from real and imag part
        pt_istft_s = pt_istft_extractor.forward(pt_stft_real, pt_stft_imag, data_length)[0, :]

        # Recover from magnitude and phase
        (pt_stft_mag, cos, sin) = magphase(pt_stft_real, pt_stft_imag)
        pt_istft_s2 = pt_istft_extractor.forward(pt_stft_mag * cos, pt_stft_mag * sin, data_length)[0, :]

        print(np.mean(np.abs(np_istft_s - pt_istft_s.data.cpu().numpy())))
        print(np.mean(np.abs(np_data - pt_istft_s.data.cpu().numpy())))
        print(np.mean(np.abs(np_data - pt_istft_s2.data.cpu().numpy())))

    elif select == 'logmel':
        dtype = np.complex64
        device = torch.device(device)
        np.random.seed(0)

        # Spectrogram parameters (the same as librosa.stft)
        sample_rate = 22050
        data_length = sample_rate * 1
        n_fft = 2048
        hop_length = 512
        win_length = 2048
        window = 'hann'
        center = True
        pad_mode = 'reflect'

        # Mel parameters (the same as librosa.feature.melspectrogram)
        n_mels = 128
        fmin = 0.
        fmax = sample_rate / 2.0

        # Power to db parameters (the same as default settings of librosa.power_to_db)
        ref = 1.0
        amin = 1e-10
        top_db = 80.0

        # Data
        np_data = np.random.uniform(-1, 1, data_length)
        pt_data = torch.Tensor(np_data).to(device)

        print('Comparing librosa and pytorch implementation of logmel '
            'spectrogram. All numbers below should be close to 0.')

        # Numpy librosa
        np_stft_matrix = librosa.stft(y=np_data, n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center, dtype=dtype,
            pad_mode=pad_mode)

        np_pad = np.pad(np_data, int(n_fft // 2), mode=pad_mode)

        np_melW = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels,
            fmin=fmin, fmax=fmax).T

        np_mel_spectrogram = np.dot(np.abs(np_stft_matrix.T) ** 2, np_melW)

        np_logmel_spectrogram = librosa.power_to_db(
            np_mel_spectrogram, ref=ref, amin=amin, top_db=top_db)

        # Pytorch
        stft_extractor = STFT(n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)

        logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=n_fft,
            n_mels=n_mels, fmin=fmin, fmax=fmax, ref=ref, amin=amin,
            top_db=top_db, freeze_parameters=True)

        stft_extractor.to(device)
        logmel_extractor.to(device)

        # Padding
        pt_pad = F.pad(pt_data[None, None, :], pad=(n_fft // 2, n_fft // 2), mode=pad_mode)[0, 0]
        print(np.mean(np.abs(np_pad - pt_pad.cpu().numpy())))

        # STFT kernels applied directly to the padded signal
        pt_stft_matrix_real = stft_extractor.conv_real(pt_pad[None, None, :])[0]
        pt_stft_matrix_imag = stft_extractor.conv_imag(pt_pad[None, None, :])[0]
        print(np.mean(np.abs(np.real(np_stft_matrix) - pt_stft_matrix_real.data.cpu().numpy())))
        print(np.mean(np.abs(np.imag(np_stft_matrix) - pt_stft_matrix_imag.data.cpu().numpy())))

        # Spectrogram
        spectrogram_extractor = Spectrogram(n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)

        spectrogram_extractor.to(device)

        pt_spectrogram = spectrogram_extractor.forward(pt_data[None, :])
        pt_mel_spectrogram = torch.matmul(pt_spectrogram, logmel_extractor.melW)
        print(np.mean(np.abs(np_mel_spectrogram - pt_mel_spectrogram.data.cpu().numpy()[0, 0])))

        # Log mel spectrogram
        pt_logmel_spectrogram = logmel_extractor.forward(pt_spectrogram)
        print(np.mean(np.abs(np_logmel_spectrogram - pt_logmel_spectrogram[0, 0].data.cpu().numpy())))

    elif select == 'enframe':
        device = torch.device(device)
        np.random.seed(0)

        # Framing parameters (the same as librosa.util.frame)
        sample_rate = 22050
        data_length = sample_rate * 1
        hop_length = 512
        win_length = 2048

        # Data
        np_data = np.random.uniform(-1, 1, data_length)
        pt_data = torch.Tensor(np_data).to(device)

        print('Comparing librosa and pytorch implementation of '
            'librosa.util.frame. All numbers below should be close to 0.')

        # Numpy librosa
        np_frames = librosa.util.frame(np_data, frame_length=win_length,
            hop_length=hop_length)

        # Pytorch
        pt_frame_extractor = Enframe(frame_length=win_length, hop_length=hop_length)
        pt_frame_extractor.to(device)

        pt_frames = pt_frame_extractor(pt_data[None, :])
        print(np.mean(np.abs(np_frames - pt_frames.data.cpu().numpy())))

    elif select == 'default':
        device = torch.device(device)
        np.random.seed(0)

        # Spectrogram parameters
        sample_rate = 22050
        data_length = sample_rate * 1
        hop_length = 512
        win_length = 2048

        # Mel parameters
        n_mels = 128

        # Data
        np_data = np.random.uniform(-1, 1, data_length)
        pt_data = torch.Tensor(np_data).to(device)

        feature_extractor = nn.Sequential(
            Spectrogram(
                hop_length=hop_length,
                win_length=win_length,
            ), LogmelFilterBank(
                sr=sample_rate,
                n_mels=n_mels,
                is_log=False,
            ))

        feature_extractor.to(device)

        print(
            'Comparing default mel spectrogram from librosa to the pytorch implementation.'
        )

        # Numpy librosa. y is passed by keyword and pad_mode is pinned to
        # Spectrogram's default ('reflect') so both sides pad identically.
        np_melspect = librosa.feature.melspectrogram(y=np_data,
            hop_length=hop_length,
            sr=sample_rate,
            win_length=win_length,
            pad_mode='reflect',
            n_mels=n_mels).T

        pt_melspect = feature_extractor(pt_data[None, :]).squeeze()
        passed = np.allclose(pt_melspect.data.to('cpu').numpy(), np_melspect)
        print(f"Passed? {passed}")

    else:
        raise ValueError(
            "select must be one of 'dft', 'stft', 'logmel', 'enframe', "
            f"'default'; got {select!r}")
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke-run the feature extractors on random data, then run every
    # debug() comparison against numpy / librosa.

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'cuda'])
    args = parser.parse_args()

    device = args.device
    norm = None     # None | 'ortho'
    np.random.seed(0)

    # Spectrogram parameters (the same as librosa.stft)
    sample_rate = 22050
    data_length = sample_rate * 1
    n_fft = 2048
    hop_length = 512
    win_length = 2048
    window = 'hann'
    center = True
    pad_mode = 'reflect'

    # Mel parameters (the same as librosa.feature.melspectrogram)
    n_mels = 128
    fmin = 0.
    fmax = sample_rate / 2.0

    # Power to db parameters (the same as default settings of librosa.power_to_db)
    ref = 1.0
    amin = 1e-10
    top_db = 80.0

    # Data
    np_data = np.random.uniform(-1, 1, data_length)
    pt_data = torch.Tensor(np_data).to(device)

    # Pytorch
    spectrogram_extractor = Spectrogram(n_fft=n_fft, hop_length=hop_length,
        win_length=win_length, window=window, center=center, pad_mode=pad_mode,
        freeze_parameters=True)

    logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=n_fft,
        n_mels=n_mels, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db,
        freeze_parameters=True)

    spectrogram_extractor.to(device)
    logmel_extractor.to(device)

    # Spectrogram
    pt_spectrogram = spectrogram_extractor.forward(pt_data[None, :])

    # Log mel spectrogram
    pt_logmel_spectrogram = logmel_extractor.forward(pt_spectrogram)

    # Uncomment for debug
    debug(select='dft', device=device)
    debug(select='stft', device=device)
    debug(select='logmel', device=device)
    debug(select='enframe', device=device)

    try:
        debug(select='default', device=device)
    except Exception as err:
        # Chain the original error so the real cause stays visible.
        raise Exception('Torchlibrosa does support librosa>=0.6.0, for '
            'comparison with librosa, please use librosa>=0.7.0!') from err
|
|