akhaliq3
spaces demo
5019931
from torchlibrosa.stft import STFT, ISTFT, magphase
import torch
import torch.nn as nn
import numpy as np
from tools.pytorch.modules.pqmf import PQMF
class FDomainHelper(nn.Module):
    """Convert waveforms to and from STFT-based representations.

    The helper wraps an ``STFT``/``ISTFT`` pair and, when ``subband`` is
    given, a ``PQMF`` filter bank so that the transforms can be applied to
    each sub-band signal instead of the full-band waveform.

    Shape comments in the methods below follow the original author's notes;
    they rely on ``STFT`` returning ``(real, imag)`` tensors laid out as
    (batch, 1, time_steps, freq_bins) -- TODO confirm against the installed
    torchlibrosa version.
    """

    def __init__(
        self,
        window_size=2048,
        hop_size=441,
        center=True,
        pad_mode='reflect',
        window='hann',
        freeze_parameters=True,
        subband=None,
        root="/Users/admin/Documents/projects/",
    ):
        """Build the STFT/ISTFT pair and, optionally, the PQMF filter bank.

        Args:
            window_size: FFT size and window length for the full-band case.
            hop_size: Hop length for the full-band case.
            center: Forwarded to ``STFT``/``ISTFT``.
            pad_mode: Forwarded to ``STFT``/``ISTFT``.
            window: Forwarded to ``STFT``/``ISTFT``.
            freeze_parameters: Forwarded to ``STFT``/``ISTFT``.
            subband: Number of sub-bands, or ``None`` for full-band
                processing. When set, ``window_size`` and ``hop_size`` are
                floor-divided by it.
            root: Third argument handed to ``PQMF``. When it is ``None`` (or
                ``subband`` is ``None``) no filter bank is created.
        """
        super().__init__()
        self.subband = subband

        # Sub-band signals are processed with proportionally smaller frames.
        if subband is None:
            n_fft, hop_length = window_size, hop_size
        else:
            n_fft = window_size // subband
            hop_length = hop_size // subband

        # Both transforms must share exactly the same framing parameters.
        transform_kwargs = dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=n_fft,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=freeze_parameters,
        )
        self.stft = STFT(**transform_kwargs)
        self.istft = ISTFT(**transform_kwargs)

        # NOTE: ``self.qmf`` only exists when both ``subband`` and ``root``
        # are provided; the ``*_subband_*`` methods require it.
        if subband is not None and root is not None:
            self.qmf = PQMF(subband, 64, root)

    def complex_spectrogram(self, input, eps=0.0):
        """Return the STFT of a single-channel batch as stacked real/imag.

        Args:
            input: [batchsize, samples]
            eps: Unused; kept for interface compatibility.

        Returns:
            [batchsize, 2 (real, imag), t-steps, f-bins]
        """
        real, imag = self.stft(input)
        return torch.cat([real, imag], dim=1)

    def reverse_complex_spectrogram(self, input, eps=0.0, length=None):
        """Invert :meth:`complex_spectrogram`.

        Args:
            input: [batchsize, 2 (real, imag), t-steps, f-bins]
            eps: Unused; kept for interface compatibility.
            length: Forwarded to ``ISTFT`` as ``length``.

        Returns:
            Whatever ``ISTFT`` returns for the given real/imag parts.
        """
        real_part = input[:, 0:1, ...]
        imag_part = input[:, 1:2, ...]
        return self.istft(real_part, imag_part, length=length)

    def spectrogram(self, input, eps=0.0):
        """Return the magnitude spectrogram of a single-channel batch.

        Args:
            input: Waveform batch; cast to float before the STFT.
            eps: Lower bound applied to the power (real^2 + imag^2) before
                the square root.
        """
        real, imag = self.stft(input.float())
        power = real ** 2 + imag ** 2
        return torch.clamp(power, min=eps) ** 0.5

    def spectrogram_phase(self, input, eps=0.0):
        """Return magnitude plus the cosine and sine of the phase.

        Args:
            input: Waveform batch; cast to float before the STFT.
            eps: Lower bound applied to the power (real^2 + imag^2) before
                the square root.

        Returns:
            Tuple ``(mag, cos, sin)`` where ``cos = real / mag`` and
            ``sin = imag / mag``. Bins whose magnitude is zero yield
            ``cos == sin == 0`` rather than NaN.
        """
        real, imag = self.stft(input.float())
        mag = torch.clamp(real ** 2 + imag ** 2, min=eps) ** 0.5
        # Floor only the denominator: with eps == 0 a silent bin has
        # mag == 0 and 0 / 0 would otherwise produce NaN. The returned
        # magnitude itself is left untouched.
        denom = torch.clamp(mag, min=1e-10)
        return mag, real / denom, imag / denom

    def wav_to_spectrogram_phase(self, input, eps=1e-8):
        """Waveform to magnitude spectrogram and phase, channel by channel.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: Forwarded to :meth:`spectrogram_phase`.

        Returns:
            Tuple ``(sps, coss, sins)``; each is the per-channel result
            concatenated along dim 1, i.e.
            (batch_size, channels_num, time_steps, freq_bins).
        """
        sp_list = []
        cos_list = []
        sin_list = []
        for channel in range(input.shape[1]):
            mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
            sp_list.append(mag)
            cos_list.append(cos)
            sin_list.append(sin)
        return (
            torch.cat(sp_list, dim=1),
            torch.cat(cos_list, dim=1),
            torch.cat(sin_list, dim=1),
        )

    def spectrogram_phase_to_wav(self, sps, coss, sins, length):
        """Invert :meth:`wav_to_spectrogram_phase`.

        Args:
            sps: Magnitudes, channels on dim 1.
            coss: Cosine of the phase, same layout as ``sps``.
            sins: Sine of the phase, same layout as ``sps``.
            length: Passed to ``ISTFT`` as its third positional argument.

        Returns:
            Per-channel ISTFT outputs, each given a channel axis at dim 1
            and concatenated along it.
        """
        wavs = [
            self.istft(
                sps[:, i : i + 1, ...] * coss[:, i : i + 1, ...],
                sps[:, i : i + 1, ...] * sins[:, i : i + 1, ...],
                length,
            ).unsqueeze(1)
            for i in range(sps.size()[1])
        ]
        return torch.cat(wavs, dim=1)

    def wav_to_spectrogram(self, input, eps=1e-8):
        """Waveform to magnitude spectrogram, channel by channel.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: Forwarded to :meth:`spectrogram`.

        Returns:
            (batch_size, channels_num, time_steps, freq_bins)
        """
        sp_list = [
            self.spectrogram(input[:, channel, :], eps=eps)
            for channel in range(input.shape[1])
        ]
        return torch.cat(sp_list, dim=1)

    def spectrogram_to_wav(self, input, spectrogram, length=None):
        """Spectrogram to waveform, borrowing the phase of ``input``.

        Args:
            input: (batch_size, channels_num, segment_samples) waveform whose
                STFT phase is combined with ``spectrogram``.
            spectrogram: (batch_size, channels_num, time_steps, freq_bins)
            length: Passed to ``ISTFT`` as its third positional argument.

        Returns:
            Per-channel ISTFT outputs stacked along dim 1.
        """
        wav_list = []
        for channel in range(input.shape[1]):
            real, imag = self.stft(input[:, channel, :])
            _, cos, sin = magphase(real, imag)
            channel_mag = spectrogram[:, channel : channel + 1, :, :]
            wav_list.append(self.istft(channel_mag * cos, channel_mag * sin, length))
        return torch.stack(wav_list, dim=1)

    # TODO: the following code is not bug free!
    def wav_to_complex_spectrogram(self, input, eps=0.0):
        """Multi-channel waveform to stacked real/imag spectrograms.

        Args:
            input: [batchsize, channels, samples]
            eps: Forwarded to :meth:`complex_spectrogram` (unused there).

        Returns:
            [batchsize, 2 (real, imag) * channels, t-steps, f-bins], with the
            real/imag pair of each channel adjacent on dim 1.
        """
        specs = [
            self.complex_spectrogram(input[:, channel, :], eps=eps)
            for channel in range(input.shape[1])
        ]
        return torch.cat(specs, dim=1)

    def complex_spectrogram_to_wav(self, input, eps=0.0, length=None):
        """Invert :meth:`wav_to_complex_spectrogram`.

        Args:
            input: [batchsize, 2 (real, imag) * channels, t-steps, f-bins]
            eps: Forwarded to :meth:`reverse_complex_spectrogram` (unused).
            length: Forwarded to ``ISTFT`` as ``length``. Pass it by keyword;
                the second positional slot is ``eps``.

        Returns:
            [batchsize, channels, samples]
        """
        channels = input.size()[1] // 2
        wavs = [
            self.reverse_complex_spectrogram(
                input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length
            ).unsqueeze(1)
            for i in range(channels)
        ]
        return torch.cat(wavs, dim=1)

    def wav_to_complex_subband_spectrogram(self, input, eps=0.0):
        """PQMF analysis followed by a complex STFT of every sub-band.

        Requires ``self.qmf`` (see ``__init__``).

        Args:
            input: [batchsize, channels, samples]
            eps: Unused; kept for interface compatibility.

        Returns:
            [batchsize, 2 (real, imag) * subband * channels, t-steps, f-bins]
        """
        subwav = self.qmf.analysis(input)  # [batchsize, subband*channels, samples]
        return self.wav_to_complex_spectrogram(subwav)

    def complex_subband_spectrogram_to_wav(self, input, eps=0.0):
        """Inverse complex STFT of every sub-band followed by PQMF synthesis.

        Requires ``self.qmf`` (see ``__init__``).

        Args:
            input: [batchsize, 2 (real, imag) * subband * channels, t-steps, f-bins]
            eps: Unused; kept for interface compatibility.

        Returns:
            [batchsize, channels, samples]
        """
        subwav = self.complex_spectrogram_to_wav(input)
        return self.qmf.synthesis(subwav)

    def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8):
        """PQMF analysis followed by magnitude/phase STFT of every sub-band.

        Requires ``self.qmf`` (see ``__init__``).

        Args:
            input: [batchsize, channels, samples]
            eps: Forwarded to :meth:`wav_to_spectrogram_phase`.

        Returns:
            Tuple ``(sps, coss, sins)`` as produced by
            :meth:`wav_to_spectrogram_phase` on the sub-band signals.

        Example (round trip, from the original author)::

            loss = torch.nn.L1Loss()
            model = FDomainHelper(subband=4)
            data = torch.randn((3,1, 44100*3))
            sps, coss, sins = model.wav_to_mag_phase_subband_spectrogram(data)
            wav = model.mag_phase_subband_spectrogram_to_wav(sps,coss,sins,44100*3//4)
            print(loss(data,wav))
            print(torch.max(torch.abs(data-wav)))
        """
        subwav = self.qmf.analysis(input)  # [batchsize, subband*channels, samples]
        return self.wav_to_spectrogram_phase(subwav, eps=eps)

    def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0):
        """Invert :meth:`wav_to_mag_phase_subband_spectrogram`.

        Requires ``self.qmf`` (see ``__init__``).

        Args:
            sps: Sub-band magnitudes.
            coss: Cosine of the sub-band phases.
            sins: Sine of the sub-band phases.
            length: Forwarded to :meth:`spectrogram_phase_to_wav`; the
                docstring example of the forward method passes the full-band
                sample count divided by ``subband``.
            eps: Unused; kept for interface compatibility.

        Returns:
            [batchsize, channels, samples]
        """
        subwav = self.spectrogram_phase_to_wav(sps, coss, sins, length)
        return self.qmf.synthesis(subwav)
if __name__ == "__main__":
    # Round-trip smoke test: waveform -> complex spectrogram -> waveform,
    # then report the mean and maximum absolute reconstruction error.
    # NOTE(review): nothing below visibly uses a name from this star import;
    # it is kept in case the module is imported for its side effects.
    from tools.file.wav import *

    loss = torch.nn.L1Loss()
    model = FDomainHelper()
    num_samples = 44100 * 5
    data = torch.randn((3, 1, num_samples))
    sps = model.wav_to_complex_spectrogram(data)
    print(sps.size())
    # The sample count must be passed as ``length``; the second positional
    # parameter of complex_spectrogram_to_wav is ``eps``, so passing it
    # positionally left the output length unspecified.
    wav = model.complex_spectrogram_to_wav(sps, length=num_samples)
    print(loss(data, wav))
    print(torch.max(torch.abs(data - wav)))