jerryhai
Track binary files with Git LFS
90f7c1e
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
import torch
import torchaudio
import numpy as np
from librosa.filters import mel as librosa_mel_fn
from .base import BaseModule
def mse_loss(x, y, mask, n_feats):
    """Return the masked squared error between ``x`` and ``y``, normalized.

    The element-wise squared difference is multiplied by ``mask`` and summed;
    the total is then divided by ``sum(mask) * n_feats``.

    Args:
        x: Predicted tensor.
        y: Target tensor, broadcast-compatible with ``x``.
        mask: Weights multiplied into the squared error (presumably a 0/1
            validity mask; verify against callers).
        n_feats: Scalar factor applied to the mask sum in the denominator.

    Returns:
        A scalar tensor holding the normalized loss. No guard is applied for
        an all-zero mask.
    """
    squared_error = (x - y) ** 2
    masked_total = torch.sum(squared_error * mask)
    normalizer = torch.sum(mask) * n_feats
    return masked_total / normalizer
def sequence_mask(length, max_length=None):
    """Build a boolean mask that is True for positions below each length.

    Args:
        length: Tensor of sequence lengths; its dtype and device are reused
            for the position index.
        max_length: Number of positions (mask width). When ``None`` the
            largest value in ``length`` is used.

    Returns:
        Boolean tensor where entry ``[i, j]`` is ``j < length[i]``.
    """
    limit = length.max() if max_length is None else max_length
    positions = torch.arange(int(limit), dtype=length.dtype, device=length.device)
    # Broadcast a row of positions against a column of lengths.
    return positions[None, :] < length[:, None]
def convert_pad_shape(pad_shape):
    """Flatten a nested padding spec after reversing its outer order.

    Args:
        pad_shape: Sequence of per-dimension sequences, e.g.
            ``[[a0, a1], [b0, b1]]``.

    Returns:
        A flat list with the outer entries in reverse order and each inner
        entry kept in its own order, e.g. ``[b0, b1, a0, a1]``.
    """
    flattened = []
    for pair in reversed(pad_shape):
        flattened.extend(pair)
    return flattened
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    """Round ``length`` up to the next multiple of ``2**num_downsamplings_in_unet``.

    Args:
        length: Starting length.
        num_downsamplings_in_unet: Exponent of the power-of-two divisor.

    Returns:
        The smallest value reachable from ``length`` in steps of 1 that is
        divisible by ``2**num_downsamplings_in_unet`` (``length`` itself when
        it already is).
    """
    divisor = 2 ** num_downsamplings_in_unet
    while length % divisor != 0:
        length += 1
    return length
class PseudoInversion(BaseModule):
    """Approximate a linear-frequency magnitude spectrogram from a log-mel one.

    A mel filterbank is built with librosa, its Moore-Penrose pseudo-inverse
    is computed once with NumPy, and the result is stored as a float32 buffer
    named ``mel_basis_inverse``.

    Args:
        n_mels: Number of mel bands in the filterbank.
        sampling_rate: Sampling rate passed to the filterbank as ``sr``.
        n_fft: FFT size passed to the filterbank.
        fmin: Lowest filterbank frequency. Defaults to 0, the value that was
            previously hard-coded.
        fmax: Highest filterbank frequency. Defaults to 8000, the value that
            was previously hard-coded. It should match the setting used when
            the mel features were extracted.
    """

    def __init__(self, n_mels, sampling_rate, n_fft, fmin=0, fmax=8000):
        super(PseudoInversion, self).__init__()
        self.n_mels = n_mels
        self.sampling_rate = sampling_rate
        self.n_fft = n_fft
        self.fmin = fmin
        self.fmax = fmax
        mel_basis = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft,
                                   n_mels=n_mels, fmin=fmin, fmax=fmax)
        mel_basis_inverse = np.linalg.pinv(mel_basis)
        mel_basis_inverse = torch.from_numpy(mel_basis_inverse).float()
        # Registered as a buffer so it is held by the module without being
        # a trainable parameter.
        self.register_buffer("mel_basis_inverse", mel_basis_inverse)

    def forward(self, log_mel_spectrogram):
        """Exponentiate the input and multiply by the pseudo-inverse filterbank.

        Args:
            log_mel_spectrogram: Natural-log mel spectrogram (``torch.exp`` is
                applied to it). Presumably shaped (batch, n_mels, frames);
                verify against callers.

        Returns:
            The matrix product ``mel_basis_inverse @ exp(log_mel_spectrogram)``.
            No clamping is applied to the result.
        """
        mel_spectrogram = torch.exp(log_mel_spectrogram)
        stftm = torch.matmul(self.mel_basis_inverse, mel_spectrogram)
        return stftm
class InitialReconstruction(BaseModule):
    """Turn a magnitude spectrogram into a waveform using zero phase.

    Args:
        n_fft: FFT size; also used as the window length.
        hop_size: Hop length between frames.
    """

    def __init__(self, n_fft, hop_size):
        super(InitialReconstruction, self).__init__()
        self.n_fft = n_fft
        self.hop_size = hop_size
        window = torch.hann_window(n_fft).float()
        self.register_buffer("window", window)

    def forward(self, stftm):
        """Inverse-STFT ``stftm`` with the imaginary part set to zero.

        Args:
            stftm: Real-valued magnitude tensor. Presumably shaped
                (batch, n_fft // 2 + 1, frames); verify against callers.

        Returns:
            The reconstructed signal with a singleton dimension inserted at
            index 1.
        """
        # torch.istft needs a complex tensor; the older real layout with a
        # trailing (real, imag) axis is rejected by current PyTorch releases.
        stft = torch.complex(stftm, torch.zeros_like(stftm))
        istft = torch.istft(stft, n_fft=self.n_fft,
                            hop_length=self.hop_size, win_length=self.n_fft,
                            window=self.window, center=True)
        return istft.unsqueeze(1)
# Fast Griffin-Lim algorithm as a PyTorch module
class FastGL(BaseModule):
    """Iterative phase reconstruction from a log-mel spectrogram.

    The log-mel input is first mapped to a magnitude estimate by
    ``PseudoInversion`` and to a starting waveform by
    ``InitialReconstruction``. The waveform is then refined by repeated
    STFT / inverse-STFT passes with a momentum term on the phase estimate.

    Args:
        n_mels: Number of mel bands, forwarded to ``PseudoInversion``.
        sampling_rate: Sampling rate, forwarded to ``PseudoInversion``.
        n_fft: FFT size; also used as the window length.
        hop_size: Hop length between frames.
        momentum: Weight of the ``angles - prev_angles`` term in the update.
    """

    def __init__(self, n_mels, sampling_rate, n_fft, hop_size, momentum=0.99):
        super(FastGL, self).__init__()
        self.n_mels = n_mels
        self.sampling_rate = sampling_rate
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.momentum = momentum
        self.pi = PseudoInversion(n_mels, sampling_rate, n_fft)
        self.ir = InitialReconstruction(n_fft, hop_size)
        window = torch.hann_window(n_fft).float()
        self.register_buffer("window", window)

    @torch.no_grad()
    def forward(self, s, n_iters=32):
        """Reconstruct a waveform from the log-mel spectrogram ``s``.

        Runs under ``torch.no_grad()``, so no gradients are tracked.

        Args:
            s: Log-mel spectrogram passed to ``PseudoInversion``.
            n_iters: Number of refinement iterations. With 0 the initial
                reconstruction is returned unchanged.

        Returns:
            The reconstructed signal with a singleton dimension inserted at
            index 1.
        """
        magnitude = self.pi(s)
        x = self.ir(magnitude).squeeze(1)
        # Real zeros are enough here: they promote to complex on first use.
        prev_angles = torch.zeros_like(magnitude)
        for _ in range(n_iters):
            # return_complex=True is required for real inputs on current
            # PyTorch; the implicit real (..., 2) output is no longer available.
            spec = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_size,
                              win_length=self.n_fft, window=self.window,
                              center=True, return_complex=True)
            # Clamp before the square root to avoid dividing by zero below.
            spec_mag = torch.sqrt(
                torch.clamp(spec.real ** 2 + spec.imag ** 2, min=1e-8))
            angles = spec / spec_mag
            # Impose the target magnitude on the momentum-extrapolated phase.
            spec = magnitude * (angles + self.momentum * (angles - prev_angles))
            x = torch.istft(spec, n_fft=self.n_fft, hop_length=self.hop_size,
                            win_length=self.n_fft, window=self.window,
                            center=True)
            prev_angles = angles
        return x.unsqueeze(1)