# Source: Hugging Face repository file uploaded by jhtonyKoo
# ("Upload 61 files", commit 2777fde; file size 11.1 kB).
"""Front-end and back-end: convert raw waveforms to spectrogram features and back."""
import torch
import torch.nn as nn
import torchaudio.functional as ta_F
import torchaudio
class FrontEnd(nn.Module):
    """Front-end: convert raw waveforms into stacked spectrogram feature channels.

    Args:
        channel: ``"mono"`` or ``"stereo"``. Mono input is handed directly to
            ``torch.stft`` (e.g. ``batch x samples``); for stereo, channels 0 and
            1 of a ``batch x channel x samples`` input are transformed separately.
        n_fft: FFT size.
        hop_length: STFT hop size; defaults to ``n_fft // 4``.
        win_length: analysis window length; defaults to ``n_fft``.
        window: ``"hann"`` or ``"hamming"`` (built with ``periodic=True``).
        device: device the analysis window is moved to.

    Raises:
        ValueError: if ``window`` is not a supported name.
    """

    def __init__(self, channel='stereo',
                 n_fft=2048,
                 hop_length=None,
                 win_length=None,
                 window="hann",
                 device=torch.device("cpu")):
        super().__init__()
        self.channel = channel
        self.n_fft = n_fft
        self.hop_length = n_fft // 4 if hop_length is None else hop_length
        self.win_length = n_fft if win_length is None else win_length
        if window == "hann":
            self.window = torch.hann_window(window_length=self.win_length, periodic=True).to(device)
        elif window == "hamming":
            self.window = torch.hamming_window(window_length=self.win_length, periodic=True).to(device)
        else:
            # Fail here rather than with an AttributeError on self.window in forward().
            raise ValueError(f"FrontEnd: unsupported window {window!r} (expected 'hann' or 'hamming')")

    def forward(self, input, mode):
        """Compute the requested features and concatenate them channel-wise.

        Args:
            input: raw waveform; ``batch x channel x samples`` for stereo,
                ``batch x samples`` for mono.
            mode: iterable of feature names. ``"cplx"`` yields real/imaginary
                channels (2 per audio channel); ``"mag"`` yields magnitude
                channels (1 per audio channel). Unrecognised names are skipped.

        Returns:
            Tensor of shape ``batch x feature channels x frequency x time``,
            features stacked in the order given by ``mode``.

        Raises:
            NameError: if ``mode`` contains no recognised feature name.
            ValueError: if ``self.channel`` is neither ``"mono"`` nor ``"stereo"``.
        """
        num_samples = input.shape[-1]
        spectra = None  # per-audio-channel STFTs, computed once and shared by all modes
        front_output_list = []
        for cur_mode in mode:
            if cur_mode not in ("cplx", "mag"):
                continue
            if spectra is None:
                spectra = self._channel_spectra(input)
            if cur_mode == "cplx":
                # Real & imaginary parts: batch x freq x time x (2 * audio channels)
                output = torch.cat(spectra, dim=-1)
            else:
                # Magnitude: batch x freq x time x audio channels
                output = torch.cat([self.mag(spec).unsqueeze(-1) for spec in spectra], dim=-1)
            output = self._trim(output, num_samples, drop_lowest_bin=(cur_mode == "mag"))
            # -> batch x feature channels x freq x time
            front_output_list.append(output.permute(0, 3, 1, 2))

        if not front_output_list:
            raise NameError("NameError at FrontEnd: check using features for front-end")
        if len(front_output_list) == 1:
            return front_output_list[0]
        return torch.cat(front_output_list, dim=1)

    def _stft(self, signal):
        """STFT of ``signal`` as a real tensor whose last axis (size 2) is (real, imag)."""
        # return_complex=True is required for real input on current PyTorch; view_as_real
        # restores the legacy (..., 2) layout the rest of this class works with.
        spec = torch.stft(signal, n_fft=self.n_fft, hop_length=self.hop_length,
                          win_length=self.win_length, window=self.window,
                          return_complex=True)
        return torch.view_as_real(spec)

    def _channel_spectra(self, input):
        """Return one real-view STFT (batch x freq x time x 2) per audio channel."""
        if self.channel == "mono":
            return [self._stft(input)]
        if self.channel == "stereo":
            return [self._stft(input[:, 0]), self._stft(input[:, 1])]
        raise ValueError(f"FrontEnd: unsupported channel {self.channel!r} (expected 'mono' or 'stereo')")

    def _trim(self, output, num_samples, drop_lowest_bin):
        """Drop the surplus frame and one frequency bin from a batch x freq x time x C tensor."""
        # With torch.stft's default center=True there are 1 + num_samples // hop frames;
        # when the length is an exact multiple of the hop, drop the trailing frame.
        if num_samples % self.hop_length == 0:
            output = output[:, :, :-1]
        # For even n_fft there are n_fft // 2 + 1 bins; drop one to keep n_fft // 2.
        # NOTE(review): "cplx" drops the highest (Nyquist) bin while "mag" drops the
        # lowest (DC) bin, and BackEnd re-inserts zeros at index 0. Kept unchanged since
        # trained models may depend on it — confirm which behaviour is intended.
        if self.n_fft % 2 == 0:
            output = output[:, 1:] if drop_lowest_bin else output[:, :-1]
        return output

    def mag(self, cplx_input, eps=1e-07):
        """Magnitude of a real-view complex tensor whose last axis is (real, imag).

        ``eps`` is added under the square root, so the result is never exactly zero.
        """
        return (cplx_input.pow(2.).sum(-1) + eps).pow(0.5)
class BackEnd(nn.Module):
    """Back-end: convert spectrogram feature channels back into waveforms.

    Args:
        channel: ``"mono"`` or ``"stereo"``.
        n_fft: FFT size.
        hop_length: STFT hop size; defaults to ``n_fft // 4``.
        win_length: synthesis window length; defaults to ``n_fft``.
        window: ``"hann"`` or ``"hamming"`` (built with ``periodic=True``).
        eps: small constant guarding divisions in Griffin-Lim.
        orig_freq: input sample rate of both resamplers.
        new_freq: target rate of ``resample_func`` (``resample_func_8k`` targets 8000).
        device: device the window and the resamplers are moved to.

    Raises:
        ValueError: if ``window`` is not a supported name.
    """

    def __init__(self, channel='stereo',
                 n_fft=2048,
                 hop_length=None,
                 win_length=None,
                 window="hann",
                 eps=1e-07,
                 orig_freq=44100,
                 new_freq=16000,
                 device=torch.device("cpu")):
        super().__init__()
        self.device = device
        self.channel = channel
        self.n_fft = n_fft
        self.hop_length = n_fft // 4 if hop_length is None else hop_length
        self.win_length = n_fft if win_length is None else win_length
        self.eps = eps
        if window == "hann":
            self.window = torch.hann_window(window_length=self.win_length, periodic=True).to(self.device)
        elif window == "hamming":
            self.window = torch.hamming_window(window_length=self.win_length, periodic=True).to(self.device)
        else:
            # Fail here rather than with an AttributeError on self.window later.
            raise ValueError(f"BackEnd: unsupported window {window!r} (expected 'hann' or 'hamming')")
        # Resamplers are not used by any method of this class; presumably callers use them.
        self.resample_func_8k = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=8000).to(self.device)
        self.resample_func = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq).to(self.device)

    def magphase_to_cplx(self, magphase_spec):
        """Map a (..., 2) tensor of (magnitude, phase) to a (..., 2) tensor of (real, imag)."""
        magnitude, angle = magphase_spec[..., 0], magphase_spec[..., 1]
        return torch.stack((magnitude * torch.cos(angle), magnitude * torch.sin(angle)), dim=-1)

    def forward(self, input, phase, mode):
        """Invert the requested feature groups and concatenate the resulting waveforms.

        Args:
            input: ``batch x feature channels x frequency x time``. Feature channels
                are consumed in ``mode`` order: ``"cplx"`` takes 2 per audio channel
                (real, imag); ``"mag_phase"`` takes 2 per audio channel (magnitude,
                phase); ``"mag"`` takes 1 per audio channel and reads the phase from
                ``phase``; ``"griff"`` treats channel 0 (and 1 for stereo) as
                magnitudes for Griffin-Lim without advancing the channel cursor.
            phase: only used by ``"mag"``; concatenated on the last axis with the
                padded, permuted input, so it is ``batch x (frequency + 1) x time x k``
                with phase for audio channel ``k`` at index ``k``.
            mode: iterable of the mode names above; unknown names are ignored.

        Returns:
            ``batch x waveform channels x samples``.

        Raises:
            NameError: if ``mode`` contains no recognised mode.
            ValueError: if ``self.channel`` is invalid for a "cplx"/"mag"/"mag_phase" mode.
        """
        # batch x channel x frequency x time  ->  batch x frequency x time x channel
        input = input.permute(0, 2, 3, 1)
        # Re-insert the frequency bin the front-end removed, as zeros at index 0.
        pad = input.new_zeros((input.shape[0], 1, input.shape[2], input.shape[3]))
        input = torch.cat((pad, input), dim=1)

        back_output_list = []
        channel_count = 0  # index of the next unread feature channel in ``input``
        for cur_mode in mode:
            if cur_mode == "cplx":
                n_audio = self._num_audio_channels()
                specs = [input[..., channel_count + 2 * k:channel_count + 2 * k + 2] for k in range(n_audio)]
                channel_count += 2 * n_audio
                back_output_list.append(self._istft_channels(specs))
            elif cur_mode in ("mag_phase", "mag"):
                n_audio = self._num_audio_channels()
                specs = []
                for k in range(n_audio):
                    if cur_mode == "mag":
                        magphase = torch.cat((input[..., channel_count + k:channel_count + k + 1],
                                              phase[..., k:k + 1]), dim=-1)
                    else:
                        magphase = input[..., channel_count + 2 * k:channel_count + 2 * k + 2]
                    specs.append(self.magphase_to_cplx(magphase))
                # Advance exactly once: 1 channel per audio channel for "mag", 2 for "mag_phase".
                channel_count += (1 if cur_mode == "mag" else 2) * n_audio
                back_output_list.append(self._istft_channels(specs))
            elif cur_mode == "griff":
                if self.channel == "mono":
                    output = self.griffin_lim(input.squeeze(-1), input.device).unsqueeze(1)
                else:
                    output = torch.cat([self.griffin_lim(input[..., k], input.device).unsqueeze(1)
                                        for k in (0, 1)], dim=1)
                back_output_list.append(output)

        if not back_output_list:
            raise NameError("NameError at BackEnd: check using features for back-end")
        if len(back_output_list) == 1:
            return back_output_list[0]
        return torch.cat(back_output_list, dim=1)

    def _num_audio_channels(self):
        """Number of audio channels implied by ``self.channel``."""
        if self.channel == "mono":
            return 1
        if self.channel == "stereo":
            return 2
        raise ValueError(f"BackEnd: unsupported channel {self.channel!r} (expected 'mono' or 'stereo')")

    def _istft(self, spec):
        """Inverse STFT of a real-view spectrogram (... x freq x time x 2) with this module's config."""
        # torch.istft expects a complex tensor; view_as_complex needs a contiguous last axis.
        return torch.istft(torch.view_as_complex(spec.contiguous()), n_fft=self.n_fft,
                           hop_length=self.hop_length, win_length=self.win_length,
                           window=self.window)

    def _istft_channels(self, specs):
        """Invert each real-view spectrogram in ``specs`` and stack them as waveform channels.

        All channels go through one batched inverse STFT (stacked on the batch axis)
        and are split again afterwards. Returns ``batch x len(specs) x samples``.
        """
        batch = specs[0].shape[0]
        wav = self._istft(torch.cat(specs, dim=0))
        return torch.stack(wav.split(batch, dim=0), dim=1)

    def griffin_lim(self, l_est, gpu, n_iter=100):
        """Reconstruct waveforms from magnitude spectrograms with Griffin-Lim.

        Args:
            l_est: magnitude spectrogram, ``batch x frequency x time`` as passed by
                forward(). It is first raised to the power ``1/0.8`` — presumably to
                undo a 0.8 power-law compression applied upstream (TODO confirm).
            gpu: device the returned waveform is moved to; the iterations run on CPU.
            n_iter: number of phase-refinement iterations.

        Returns:
            ``batch x samples`` waveform on ``gpu``, scaled so the largest absolute
            sample over the whole batch is 1.
        """
        import math

        magnitude = torch.pow(l_est.detach().cpu(), 1 / 0.80)
        stft_kwargs = dict(n_fft=self.n_fft, hop_length=self.hop_length,
                           win_length=self.win_length, window=self.window.cpu())
        # Random initial phase, uniform in [-pi, pi).
        init_phase = (2 * torch.rand_like(magnitude) - 1) * math.pi
        spec = torch.polar(magnitude, init_phase)
        for _ in range(n_iter):
            wav = torch.istft(spec, **stft_kwargs)
            rebuilt = torch.stft(wav, return_complex=True, **stft_kwargs)
            # Keep the target magnitude, adopt the phase of the re-analysed signal.
            spec = magnitude * rebuilt / torch.sqrt(rebuilt.abs() ** 2 + self.eps)
        wav = torch.istft(spec, **stft_kwargs)
        # clamp_min guards against division by zero for an all-silent estimate.
        return (wav / wav.abs().max().clamp_min(self.eps)).to(gpu)
if __name__ == '__main__':
    # Smoke test: push one random stereo batch through the front-end and report shapes.
    batch_size, channel = 16, 2
    segment_length = 6 * 128 * 512
    input_wav = torch.rand((batch_size, channel, segment_length))
    front_end = FrontEnd(channel="stereo")
    features = front_end(input_wav, mode=["cplx", "mag"])
    print(f"Input shape : {input_wav.shape}\nOutput shape : {features.shape}")