|
import torch |
|
from torch.nn import functional as F |
|
from stft_loss import MultiResolutionSTFTLoss |
|
|
|
|
|
import commons |
|
|
|
|
|
def feature_loss(fmap_r, fmap_g):
    """Accumulate the L1 distance between paired feature maps.

    ``fmap_r`` and ``fmap_g`` are iterated in lockstep; each paired entry is
    itself iterated in lockstep to obtain individual tensors. For every
    tensor pair the mean absolute difference is added to a running total.

    The ``fmap_r`` tensors are detached, so gradients only flow through the
    ``fmap_g`` tensors. Both sides are cast to float32 before subtraction.

    Returns:
        Twice the accumulated total. When nothing is paired (empty inputs)
        the result is the plain integer ``0``.
    """
    total = 0
    for real_maps, gen_maps in zip(fmap_r, fmap_g):
        for real_feat, gen_feat in zip(real_maps, gen_maps):
            # The real-side features act as a fixed target (no gradient).
            target = real_feat.float().detach()
            difference = target - gen_feat.float()
            total = total + difference.abs().mean()
    return total * 2
|
|
|
|
|
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Squared-error discriminator loss over paired discriminator outputs.

    The two sequences are iterated in lockstep. For each pair, after casting
    to float32:

    * the real term is ``mean((1 - real) ** 2)``
    * the generated term is ``mean(generated ** 2)``

    Returns:
        A 3-tuple ``(loss, r_losses, g_losses)`` where ``loss`` is the sum of
        all real and generated terms (the integer ``0`` for empty inputs) and
        the two lists hold the per-pair terms as Python floats (via
        ``.item()``), in input order.
    """
    total = 0
    real_terms, fake_terms = [], []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = (1 - real_out.float()).pow(2).mean()
        fake_term = fake_out.float().pow(2).mean()
        total = total + (real_term + fake_term)
        # Scalars are stored as floats, detached from the autograd graph.
        real_terms.append(real_term.item())
        fake_terms.append(fake_term.item())
    return total, real_terms, fake_terms
|
|
|
|
|
def generator_loss(disc_outputs):
    """Squared-error generator loss over a sequence of discriminator outputs.

    Each output is cast to float32 and contributes ``mean((1 - output) ** 2)``.

    Returns:
        A 2-tuple ``(loss, gen_losses)``: the sum of all per-output terms (the
        integer ``0`` for an empty input) and the list of those terms. The
        list entries are tensors, not Python floats.
    """
    per_output = [(1 - out.float()).pow(2).mean() for out in disc_outputs]
    # sum() starts from the integer 0, matching the empty-input result.
    total = sum(per_output)
    return total, per_output
|
|
|
|
|
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    Masked KL-style loss term.

    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]

    All inputs are cast to float32. The per-element value is

        logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p)**2 * exp(-2 * logs_p)

    which is multiplied by ``z_mask``, summed, and divided by
    ``sum(z_mask)``. An all-zero mask therefore yields NaN (0 / 0).
    """
    z_p, logs_q, m_p, logs_p, z_mask = (
        tensor.float() for tensor in (z_p, logs_q, m_p, logs_p, z_mask)
    )

    log_scale_term = logs_p - logs_q - 0.5
    squared_error_term = 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)

    masked_total = torch.sum((log_scale_term + squared_error_term) * z_mask)
    return masked_total / torch.sum(z_mask)
|
|
|
def subband_stft_loss(h, y_mb, y_hat_mb):
    """Multi-resolution STFT loss between two multi-band signals.

    A ``MultiResolutionSTFTLoss`` is built on every call from
    ``h.train.fft_sizes``, ``h.train.hop_sizes`` and ``h.train.win_lengths``.
    Both inputs are flattened to 2-D, keeping dim 2 as the last axis, and
    ``y_hat_mb`` is truncated along that axis to the length of ``y_mb``.

    NOTE(review): inputs are presumably (batch, subbands, time) — confirm
    against the caller. ``.view`` requires them to be contiguous.

    Returns:
        The sum of the two values returned by the criterion, called with the
        truncated ``y_hat_mb`` first and ``y_mb`` second.
    """
    train_cfg = h.train
    criterion = MultiResolutionSTFTLoss(
        train_cfg.fft_sizes, train_cfg.hop_sizes, train_cfg.win_lengths
    )

    target = y_mb.view(-1, y_mb.size(2))
    predicted = y_hat_mb.view(-1, y_hat_mb.size(2))
    # Drop any samples of the prediction that extend past the target.
    predicted = predicted[:, : target.size(-1)]

    sc_term, mag_term = criterion(predicted, target)
    return sc_term + mag_term
|
|
|
|