from torchlibrosa.stft import STFT, ISTFT, magphase
import torch
import torch.nn as nn
import numpy as np
from tools.pytorch.modules.pqmf import PQMF


class FDomainHelper(nn.Module):
    """Convert between waveforms and (optionally sub-band) spectrogram representations.

    Wraps a torchlibrosa ``STFT``/``ISTFT`` pair and, when ``subband`` is given,
    a ``PQMF`` filter bank, so multi-channel waveforms can be mapped to
    magnitude/phase or real/imaginary spectrograms and back.
    """

    def __init__(
        self,
        window_size=2048,
        hop_size=441,
        center=True,
        pad_mode='reflect',
        window='hann',
        freeze_parameters=True,
        subband=None,
        root="/Users/admin/Documents/projects/",
    ):
        """
        Args:
            window_size: FFT size / window length in samples (before sub-band division).
            hop_size: hop length in samples (before sub-band division).
            center, pad_mode, window, freeze_parameters: forwarded unchanged to
                torchlibrosa ``STFT`` and ``ISTFT``.
            subband: number of PQMF sub-bands, or ``None`` for full-band processing.
                When set, ``window_size`` and ``hop_size`` are integer-divided by it.
            root: forwarded to ``PQMF`` as its third argument.
                NOTE(review): the default is an absolute, machine-specific path;
                pass an explicit ``root`` when using the sub-band methods.
        """
        super().__init__()
        self.subband = subband

        # Sub-band signals are shorter by a factor of ``subband`` (see the
        # example in wav_to_mag_phase_subband_spectrogram), so the transform
        # sizes are scaled down by the same factor.
        divisor = 1 if subband is None else subband
        n_fft = window_size // divisor
        stft_kwargs = dict(
            n_fft=n_fft,
            hop_length=hop_size // divisor,
            win_length=n_fft,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=freeze_parameters,
        )
        self.stft = STFT(**stft_kwargs)
        self.istft = ISTFT(**stft_kwargs)

        # ``self.qmf`` only exists in sub-band mode with a root; the *subband*
        # methods below raise AttributeError otherwise.
        if subband is not None and root is not None:
            self.qmf = PQMF(subband, 64, root)

    @staticmethod
    def _magnitude(real, imag, eps):
        """Return ``sqrt(max(real**2 + imag**2, eps))`` element-wise."""
        return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5

    def complex_spectrogram(self, input, eps=0.0):
        """Single-channel waveform -> stacked real/imaginary spectrogram.

        Args:
            input: [batchsize, samples] waveform.
            eps: unused; kept for interface symmetry.

        Returns:
            [batchsize, 2, t-steps, f-bins] with channel 0 = real, channel 1 = imag.
        """
        real, imag = self.stft(input)
        return torch.cat([real, imag], dim=1)

    def reverse_complex_spectrogram(self, input, eps=0.0, length=None):
        """Inverse of :meth:`complex_spectrogram`.

        Args:
            input: [batchsize, 2[real, imag], t-steps, f-bins].
            eps: unused.
            length: output length in samples, forwarded to ISTFT.

        Returns:
            The ISTFT output for the single channel.
        """
        return self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length)

    def spectrogram(self, input, eps=0.0):
        """Magnitude spectrogram ``sqrt(max(|X|^2, eps))`` of a [batchsize, samples] waveform."""
        real, imag = self.stft(input.float())
        return self._magnitude(real, imag, eps)

    def spectrogram_phase(self, input, eps=0.0):
        """Magnitude plus phase (as cos / sin) of a [batchsize, samples] waveform.

        Returns:
            mag, cos, sin -- each shaped like the STFT output, such that
            ``mag * cos`` and ``mag * sin`` recover the real / imaginary parts.
        """
        real, imag = self.stft(input.float())
        mag = self._magnitude(real, imag, eps)
        # With eps == 0 a silent bin has mag == 0 and real / mag would be
        # 0 / 0 = NaN. Dividing by a strictly positive floor gives cos = sin = 0
        # there instead; for eps > 0, mag already exceeds this floor, so
        # results are unchanged.
        denom = torch.clamp(mag, min=torch.finfo(mag.dtype).tiny)
        return mag, real / denom, imag / denom

    def wav_to_spectrogram_phase(self, input, eps=1e-8):
        """Multi-channel waveform -> magnitude and phase spectrograms.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: magnitude floor, see :meth:`spectrogram_phase`.

        Returns:
            sps, coss, sins: each (batch_size, channels_num, time_steps, freq_bins)
        """
        mags, coss, sins = [], [], []
        for channel in range(input.shape[1]):
            mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
            mags.append(mag)
            coss.append(cos)
            sins.append(sin)
        return torch.cat(mags, dim=1), torch.cat(coss, dim=1), torch.cat(sins, dim=1)

    def spectrogram_phase_to_wav(self, sps, coss, sins, length):
        """Inverse of :meth:`wav_to_spectrogram_phase`.

        Args:
            sps, coss, sins: each (batch_size, channels_num, time_steps, freq_bins).
            length: output length in samples, forwarded to ISTFT.

        Returns:
            (batch_size, channels_num, samples) waveform.
        """
        wavs = []
        for i in range(sps.shape[1]):
            ch = slice(i, i + 1)
            wav = self.istft(sps[:, ch, ...] * coss[:, ch, ...], sps[:, ch, ...] * sins[:, ch, ...], length)
            wavs.append(wav.unsqueeze(1))
        return torch.cat(wavs, dim=1)

    def wav_to_spectrogram(self, input, eps=1e-8):
        """Multi-channel waveform -> magnitude spectrogram.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: magnitude floor, see :meth:`spectrogram`.

        Returns:
            (batch_size, channels_num, time_steps, freq_bins)
        """
        return torch.cat(
            [self.spectrogram(input[:, ch, :], eps=eps) for ch in range(input.shape[1])],
            dim=1,
        )

    def spectrogram_to_wav(self, input, spectrogram, length=None):
        """Resynthesize a waveform from a magnitude spectrogram, reusing the phase of ``input``.

        Args:
            input: (batch_size, channels_num, segment_samples) waveform whose phase is reused.
            spectrogram: (batch_size, channels_num, time_steps, freq_bins) magnitudes.
            length: output length in samples, forwarded to ISTFT.

        Returns:
            (batch_size, channels_num, samples)
        """
        wavs = []
        for ch in range(input.shape[1]):
            real, imag = self.stft(input[:, ch, :])
            _, cos, sin = magphase(real, imag)
            mag = spectrogram[:, ch:ch + 1, :, :]
            wavs.append(self.istft(mag * cos, mag * sin, length))
        return torch.stack(wavs, dim=1)

    # TODO: the original author flagged the methods below as "not bug free".
    def wav_to_complex_spectrogram(self, input, eps=0.0):
        """[batchsize, channels, samples] -> [batchsize, 2[real, imag] * channels, t-steps, f-bins].

        Channels are interleaved as [real_0, imag_0, real_1, imag_1, ...].
        """
        return torch.cat(
            [self.complex_spectrogram(input[:, ch, :], eps=eps) for ch in range(input.shape[1])],
            dim=1,
        )

    def complex_spectrogram_to_wav(self, input, eps=0.0, length=None):
        """Inverse of :meth:`wav_to_complex_spectrogram`.

        Args:
            input: [batchsize, 2[real, imag] * channels, t-steps, f-bins].
            eps: unused.
            length: output length in samples, forwarded to ISTFT.

        Returns:
            [batchsize, channels, samples]
        """
        wavs = [
            self.reverse_complex_spectrogram(
                input[:, 2 * ch:2 * ch + 2, ...], eps=eps, length=length
            ).unsqueeze(1)
            for ch in range(input.shape[1] // 2)
        ]
        return torch.cat(wavs, dim=1)

    def wav_to_complex_subband_spectrogram(self, input, eps=0.0):
        """[batchsize, channels, samples] -> [batchsize, 2[real, imag] * subband * channels, t-steps, f-bins].

        Requires sub-band mode (``self.qmf``).
        """
        subwav = self.qmf.analysis(input)  # [batchsize, subband * channels, samples]
        return self.wav_to_complex_spectrogram(subwav, eps=eps)

    def complex_subband_spectrogram_to_wav(self, input, eps=0.0, length=None):
        """Inverse of :meth:`wav_to_complex_subband_spectrogram`.

        Args:
            input: [batchsize, 2[real, imag] * subband * channels, t-steps, f-bins].
            eps: unused.
            length: optional length in samples of each *sub-band* signal,
                forwarded to the ISTFT before PQMF synthesis (``None`` keeps the
                previous behaviour).

        Returns:
            [batchsize, channels, samples]
        """
        subwav = self.complex_spectrogram_to_wav(input, eps=eps, length=length)
        return self.qmf.synthesis(subwav)

    def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8):
        """PQMF-analyse a waveform, then compute magnitude / phase spectrograms of each sub-band.

        Args:
            input: [batchsize, channels, samples]
            eps: magnitude floor, see :meth:`spectrogram_phase`.

        Returns:
            sps, coss, sins: each [batchsize, subband * channels, t-steps, f-bins].

        Example::

            loss = torch.nn.L1Loss()
            model = FDomainHelper(subband=4)
            data = torch.randn((3, 1, 44100 * 3))
            sps, coss, sins = model.wav_to_mag_phase_subband_spectrogram(data)
            wav = model.mag_phase_subband_spectrogram_to_wav(sps, coss, sins, 44100 * 3 // 4)
            print(loss(data, wav))
            print(torch.max(torch.abs(data - wav)))
        """
        subwav = self.qmf.analysis(input)  # [batchsize, subband * channels, samples]
        return self.wav_to_spectrogram_phase(subwav, eps=eps)

    def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0):
        """Inverse of :meth:`wav_to_mag_phase_subband_spectrogram`.

        Args:
            sps, coss, sins: each [batchsize, subband * channels, t-steps, f-bins].
            length: length in samples of each sub-band signal.
            eps: unused.

        Returns:
            [batchsize, channels, samples] waveform after PQMF synthesis.
        """
        subwav = self.spectrogram_phase_to_wav(sps, coss, sins, length)
        return self.qmf.synthesis(subwav)


if __name__ == "__main__":
    # Round-trip smoke test: waveform -> complex spectrogram -> waveform.
    num_samples = 44100 * 5
    loss = torch.nn.L1Loss()
    model = FDomainHelper()
    data = torch.randn((3, 1, num_samples))
    sps = model.wav_to_complex_spectrogram(data)
    print(sps.size())
    # ``length`` must be passed by keyword: the second positional parameter is ``eps``.
    wav = model.complex_spectrogram_to_wav(sps, length=num_samples)
    print(loss(data, wav))
    print(torch.max(torch.abs(data - wav)))