Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import torchaudio | |
from transformers import AutoModel | |
class SpectralConvergengeLoss(torch.nn.Module):
    """Spectral-convergence style loss using L1 norms.

    Computes ``||y_mag - x_mag||_1 / ||y_mag||_1`` where both norms run over
    every element of the inputs (no ``dim`` is given), so any pair of
    same-shaped tensors is accepted. Within this file it is fed normalised
    log-mel spectrograms by ``STFTLoss``.
    """

    def __init__(self):
        """Create the loss module; it has no parameters or buffers."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Return the L1 spectral-convergence loss.

        Args:
            x_mag (Tensor): Spectrogram-like tensor of the predicted signal.
            y_mag (Tensor): Spectrogram-like tensor of the ground-truth signal,
                same shape as ``x_mag``.

        Returns:
            Tensor: Scalar loss value.
        """
        residual_norm = torch.norm(y_mag - x_mag, p=1)
        reference_norm = torch.norm(y_mag, p=1)
        return residual_norm / reference_norm
class STFTLoss(torch.nn.Module):
    """Single-resolution loss comparing normalised log-mel spectrograms.

    Despite the name, the inputs are converted with
    ``torchaudio.transforms.MelSpectrogram`` (torchaudio defaults for every
    argument not set below). The result is normalised as
    ``(log(1e-5 + mel) - mean) / std`` with fixed ``mean=-4`` and ``std=4``,
    and the two spectrograms are compared with ``SpectralConvergengeLoss``.
    """

    # Fixed statistics used to normalise the log-mel spectrogram.
    _LOG_MEL_MEAN = -4
    _LOG_MEL_STD = 4

    def __init__(
        self,
        fft_size=1024,
        shift_size=120,
        win_length=600,
        window=torch.hann_window,
        sample_rate=24000,
    ):
        """Initialize STFT loss module.

        Args:
            fft_size (int): FFT size (``n_fft``) of the mel front-end.
            shift_size (int): Hop length in samples.
            win_length (int): Window length in samples.
            window (Callable): Window factory passed as ``window_fn``
                (e.g. ``torch.hann_window``).
            sample_rate (int): Sample rate of the input audio. Defaults to
                24000, the value that was previously hard-coded.
        """
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=fft_size,
            win_length=win_length,
            hop_length=shift_size,
            window_fn=window,
        )
        self.spectral_convergenge_loss = SpectralConvergengeLoss()

    def _log_mel(self, wav):
        """Return the normalised log-mel spectrogram of ``wav``."""
        mel = self.to_mel(wav)
        # 1e-5 keeps the log finite for silent bins.
        return (torch.log(1e-5 + mel) - self._LOG_MEL_MEAN) / self._LOG_MEL_STD

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss between the normalised log-mel
            spectrograms of ``x`` and ``y`` (a single tensor).
        """
        x_mag = self._log_mel(x)
        y_mag = self._log_mel(y)
        return self.spectral_convergenge_loss(x_mag, y_mag)
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Mean of ``STFTLoss`` values over several STFT resolutions.

    One ``STFTLoss`` is built per ``(fft_size, hop_size, win_length)`` triple.
    """

    def __init__(
        self,
        fft_sizes=(1024, 2048, 512),
        hop_sizes=(120, 240, 50),
        win_lengths=(600, 1200, 240),
        window=torch.hann_window,
    ):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_sizes (Sequence[int]): FFT size of each resolution.
            hop_sizes (Sequence[int]): Hop size of each resolution.
            win_lengths (Sequence[int]): Window length of each resolution.
            window (Callable): Window factory forwarded to every ``STFTLoss``
                (e.g. ``torch.hann_window``).

        Raises:
            ValueError: If the three sequences differ in length or are empty.
        """
        super().__init__()
        # Tuple defaults avoid shared mutable default arguments. Validation is
        # explicit rather than an ``assert`` (which ``python -O`` strips), and
        # an empty configuration is rejected here instead of failing later
        # with a division by zero in ``forward``.
        n_res = len(fft_sizes)
        if not (n_res == len(hop_sizes) == len(win_lengths)):
            raise ValueError(
                "fft_sizes, hop_sizes and win_lengths must have the same length, "
                f"got {len(fft_sizes)}, {len(hop_sizes)} and {len(win_lengths)}"
            )
        if n_res == 0:
            raise ValueError("at least one STFT resolution is required")
        self.stft_losses = torch.nn.ModuleList(
            [
                STFTLoss(fs, ss, wl, window)
                for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths)
            ]
        )

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Mean of the per-resolution ``STFTLoss`` values.
        """
        total = sum(stft_loss(x, y) for stft_loss in self.stft_losses)
        return total / len(self.stft_losses)
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss between real and generated feature maps.

    Args:
        fmap_r: One sequence of feature-map tensors per discriminator, for the
            real input.
        fmap_g: The matching sequences for the generated input.

    Returns:
        Twice the sum, over every paired feature map, of the mean absolute
        difference. Returns the integer ``0`` when there is nothing to pair.
    """
    per_map_l1 = (
        torch.mean(torch.abs(real_map - fake_map))
        for real_maps, fake_maps in zip(fmap_r, fmap_g)
        for real_map, fake_map in zip(real_maps, fake_maps)
    )
    return sum(per_map_l1) * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares GAN loss for a set of discriminators.

    Real outputs are pushed towards 1 and generated outputs towards 0.

    Args:
        disc_real_outputs: Discriminator output tensors for real inputs.
        disc_generated_outputs: Matching output tensors for generated inputs.

    Returns:
        tuple: ``(total, real_terms, fake_terms)``.
        ``total`` is the summed loss (integer ``0`` for empty inputs).
        ``real_terms`` and ``fake_terms`` are lists of Python floats with the
        per-discriminator components.
    """
    terms = [
        (torch.mean((1 - real_out) ** 2), torch.mean(fake_out**2))
        for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs)
    ]
    total = sum(real_term + fake_term for real_term, fake_term in terms)
    real_terms = [real_term.item() for real_term, _ in terms]
    fake_terms = [fake_term.item() for _, fake_term in terms]
    return total, real_terms, fake_terms
def generator_loss(disc_outputs):
    """Least-squares GAN loss for the generator.

    Each discriminator output on generated data is pushed towards 1.

    Args:
        disc_outputs: Discriminator output tensors for generated inputs.

    Returns:
        tuple: ``(total, per_disc)``. ``total`` is the sum of the
        per-discriminator losses (integer ``0`` for empty input), and
        ``per_disc`` is the list of those loss tensors.
    """
    per_disc = [torch.mean((1 - fake_out) ** 2) for fake_out in disc_outputs]
    return sum(per_disc), per_disc
""" https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """ | |
def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs, tau=0.04):
    """Truncated pointwise relativistic least-squares (TPRLS) discriminator loss.

    See https://dl.acm.org/doi/abs/10.1145/3573834.3574506.

    For each pair of outputs the differences ``dr - dg`` are centred on their
    median ``m_DG``. The squared deviations of the entries selected by
    ``dr < dg + m_DG`` are averaged into ``L_rel``, and the term added is
    ``tau - relu(tau - L_rel)``, i.e. ``L_rel`` capped at ``tau``.

    Args:
        disc_real_outputs: Discriminator output tensors for real inputs.
        disc_generated_outputs: Matching output tensors for generated inputs.
        tau (float): Truncation threshold (previously hard-coded to 0.04).

    Returns:
        Sum of the per-discriminator terms (integer ``0`` for empty inputs).
    """
    loss = 0
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        diff = dr - dg
        m_DG = torch.median(diff)
        mask = dr < dg + m_DG
        sq_dev = (diff - m_DG) ** 2
        # The selection can be empty, e.g. when all differences are equal or
        # the output has a single element. torch.mean of an empty tensor is
        # NaN, which would poison the loss. Average manually and let an empty
        # selection contribute L_rel = 0 instead.
        n_selected = mask.sum().clamp(min=1)
        L_rel = torch.where(mask, sq_dev, torch.zeros_like(sq_dev)).sum() / n_selected
        loss += tau - F.relu(tau - L_rel)
    return loss
def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs, tau=0.04):
    """Truncated pointwise relativistic least-squares (TPRLS) generator loss.

    See https://dl.acm.org/doi/abs/10.1145/3573834.3574506.

    This uses the same formula as the discriminator variant, but with the
    roles of real and generated outputs exchanged. The differences are
    generated-minus-real, centred on their median, and the mean squared
    deviation of the selected entries is capped at ``tau``.

    Args:
        disc_real_outputs: Discriminator output tensors for real inputs.
        disc_generated_outputs: Matching output tensors for generated inputs.
        tau (float): Truncation threshold (previously hard-coded to 0.04).

    Returns:
        Sum of the per-discriminator terms (integer ``0`` for empty inputs).
    """
    loss = 0
    # Note the swapped unpacking: ``dg`` walks the real outputs and ``dr`` the
    # generated ones, so ``dr - dg`` below is generated-minus-real.
    for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
        diff = dr - dg
        m_DG = torch.median(diff)
        mask = dr < dg + m_DG
        sq_dev = (diff - m_DG) ** 2
        # An empty selection (all differences equal, or a single element) made
        # torch.mean return NaN. Average manually so it contributes 0 instead.
        n_selected = mask.sum().clamp(min=1)
        L_rel = torch.where(mask, sq_dev, torch.zeros_like(sq_dev)).sum() / n_selected
        loss += tau - F.relu(tau - L_rel)
    return loss
class GeneratorLoss(torch.nn.Module):
    """Combined generator objective against two discriminators.

    Sums, for both ``mpd`` and ``msd``, the least-squares adversarial loss,
    the feature-matching loss and the TPRLS relativistic loss.
    """

    def __init__(self, mpd, msd):
        """Store the two discriminators.

        Args:
            mpd: Callable ``(y, y_hat) -> (real_out, fake_out, real_fmaps, fake_fmaps)``.
            msd: Callable with the same return structure as ``mpd``.
        """
        super().__init__()
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        """Return the total generator loss as a scalar tensor.

        Args:
            y: Real signal, passed first to both discriminators.
            y_hat: Generated signal.
        """
        real_f, fake_f, feats_real_f, feats_fake_f = self.mpd(y, y_hat)
        real_s, fake_s, feats_real_s, feats_fake_s = self.msd(y, y_hat)

        fm_f = feature_loss(feats_real_f, feats_fake_f)
        fm_s = feature_loss(feats_real_s, feats_fake_s)

        # The per-discriminator breakdown returned by generator_loss is unused.
        adv_f, _ = generator_loss(fake_f)
        adv_s, _ = generator_loss(fake_s)

        rel = generator_TPRLS_loss(real_f, fake_f) + generator_TPRLS_loss(
            real_s, fake_s
        )

        total = adv_s + adv_f + fm_s + fm_f + rel
        return total.mean()
class DiscriminatorLoss(torch.nn.Module):
    """Combined discriminator objective for two discriminators.

    Sums, for both ``mpd`` and ``msd``, the least-squares discriminator loss
    and the TPRLS relativistic discriminator loss.
    """

    def __init__(self, mpd, msd):
        """Store the two discriminators.

        Args:
            mpd: Callable ``(y, y_hat) -> (real_out, fake_out, real_fmaps, fake_fmaps)``.
            msd: Callable with the same return structure as ``mpd``.
        """
        super().__init__()
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        """Return the total discriminator loss as a scalar tensor.

        Args:
            y: Real signal, passed first to both discriminators.
            y_hat: Generated signal.
        """
        # Feature maps are not needed for the discriminator objective.
        real_f, fake_f, _, _ = self.mpd(y, y_hat)
        adv_f, _, _ = discriminator_loss(real_f, fake_f)

        real_s, fake_s, _, _ = self.msd(y, y_hat)
        adv_s, _, _ = discriminator_loss(real_s, fake_s)

        rel = discriminator_TPRLS_loss(real_f, fake_f) + discriminator_TPRLS_loss(
            real_s, fake_s
        )

        total = adv_s + adv_f + rel
        return total.mean()
class WavLMLoss(torch.nn.Module):
    """Losses computed on the hidden states of a pretrained speech model.

    Waveforms at ``model_sr`` are resampled to ``slm_sr``. They are then fed to
    the model loaded by ``AutoModel.from_pretrained(model)`` with
    ``output_hidden_states=True``.

    Provides an L1 feature-matching loss (``forward``) and least-squares GAN
    losses (``generator`` / ``discriminator``). The GAN losses use a
    discriminator ``wd`` applied to all hidden states stacked together.
    """

    def __init__(self, model, wd, model_sr, slm_sr=16000):
        """Load the speech model and build the resampler.

        Args:
            model: Identifier or path passed to ``AutoModel.from_pretrained``.
            wd: Discriminator applied to the stacked hidden states.
            model_sr (int): Sample rate of the waveforms given to this module.
            slm_sr (int): Sample rate fed to the speech model (default 16000).
        """
        super().__init__()
        self.wavlm = AutoModel.from_pretrained(model)
        self.wd = wd
        self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)

    @staticmethod
    def _drop_channel_dim(wav):
        """Remove a singleton channel axis from a 3-D ``(B, 1, T)`` tensor.

        The previous bare ``squeeze()`` in ``forward`` also removed the batch
        axis when B == 1, producing a 1-D tensor. Hugging Face speech models
        document ``input_values`` as ``(batch_size, sequence_length)``, so
        only axis 1 is squeezed here. Other shapes are returned unchanged.
        """
        if wav.dim() == 3 and wav.size(1) == 1:
            return wav.squeeze(1)
        return wav

    def _hidden_states(self, wav):
        """Resample ``wav`` and return the model's tuple of hidden states."""
        wav_16 = self._drop_channel_dim(self.resample(wav))
        return self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states

    @staticmethod
    def _stack_layers(hidden_states):
        """Merge all hidden states into one discriminator input.

        The states are stacked on a new dim 1, the last two axes are swapped,
        and dims 1-2 are flattened. If each state is (B, frames, hidden)
        (TODO confirm for the chosen model), the result is
        (B, n_states * hidden, frames).
        """
        return (
            torch.stack(hidden_states, dim=1)
            .transpose(-1, -2)
            .flatten(start_dim=1, end_dim=2)
        )

    def forward(self, wav, y_rec):
        """L1 feature-matching loss between the hidden states of two signals.

        The reference ``wav`` branch runs under ``torch.no_grad()``, so
        gradients flow only through ``y_rec``.

        Returns:
            Tensor: Sum over hidden states of the mean absolute difference.
        """
        with torch.no_grad():
            wav_embeddings = self._hidden_states(wav)
        y_rec_embeddings = self._hidden_states(y_rec)

        floss = 0
        for er, eg in zip(wav_embeddings, y_rec_embeddings):
            floss += torch.mean(torch.abs(er - eg))
        return floss.mean()

    def generator(self, y_rec):
        """Least-squares generator loss: push ``wd`` on ``y_rec`` towards 1."""
        y_df_hat_g = self.wd(self._stack_layers(self._hidden_states(y_rec)))
        return torch.mean((1 - y_df_hat_g) ** 2)

    def discriminator(self, wav, y_rec):
        """Least-squares discriminator loss on real ``wav`` vs. generated ``y_rec``.

        The speech model runs without gradients, while ``wd`` is evaluated
        with gradients enabled.
        """
        with torch.no_grad():
            y_embeddings = self._stack_layers(self._hidden_states(wav))
            y_rec_embeddings = self._stack_layers(self._hidden_states(y_rec))

        y_df_hat_r = self.wd(y_embeddings)
        y_df_hat_g = self.wd(y_rec_embeddings)

        r_loss = torch.mean((1 - y_df_hat_r) ** 2)
        g_loss = torch.mean(y_df_hat_g**2)
        loss_disc_f = r_loss + g_loss
        return loss_disc_f.mean()

    def discriminator_forward(self, wav):
        """Return the raw ``wd`` output for ``wav``.

        The speech model is run without gradients.
        """
        with torch.no_grad():
            y_embeddings = self._stack_layers(self._hidden_states(wav))
        return self.wd(y_embeddings)