# NOTE: page-header residue from the hosting site ("Spaces: Sleeping"); not part of this module.
import math | |
import torch | |
from torch.nn import functional as F | |
def feature_loss(fmap_r, fmap_g):
    """Return the feature-matching loss between two sets of feature maps.

    For every paired (real, generated) feature tensor, the mean absolute
    difference is computed in float32; the per-pair means are summed and the
    total is doubled.

    Args:
        fmap_r: Iterable of iterables of tensors (features for real input).
            These are detached, so no gradient flows back through them.
        fmap_g: Iterable of iterables of tensors (features for generated
            input), paired positionally with ``fmap_r`` via ``zip``.

    Returns:
        A scalar tensor holding twice the summed mean-absolute differences,
        or the plain integer ``0`` when there are no feature pairs.
    """
    total = 0
    for real_maps, gen_maps in zip(fmap_r, fmap_g):
        for real_feat, gen_feat in zip(real_maps, gen_maps):
            # The real features act as a fixed target: cut them out of the graph.
            target = real_feat.float().detach()
            total = total + (target - gen_feat.float()).abs().mean()
    return total * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Return the squared-error (least-squares GAN style) discriminator loss.

    Each real output is pushed towards 1 via ``mean((1 - d_real) ** 2)`` and
    each generated output towards 0 via ``mean(d_gen ** 2)``. Outputs are
    upcast to float32 before the arithmetic.

    Args:
        disc_real_outputs: Iterable of discriminator output tensors for real
            input.
        disc_generated_outputs: Iterable of discriminator output tensors for
            generated input, paired positionally with the real outputs.

    Returns:
        A tuple ``(loss, r_losses, g_losses)``: the summed loss (a scalar
        tensor, or integer ``0`` when there are no pairs), and two lists of
        Python floats with the per-discriminator real and generated terms.
    """
    total = 0
    r_losses, g_losses = [], []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = ((1 - real_out.float()) ** 2).mean()
        fake_term = (fake_out.float() ** 2).mean()
        total = total + (real_term + fake_term)
        # Only plain floats are kept in the per-discriminator lists.
        r_losses.append(real_term.item())
        g_losses.append(fake_term.item())
    return total, r_losses, g_losses
def generator_loss(disc_outputs):
    """Return the squared-error (least-squares GAN style) generator loss.

    Each discriminator output for generated input contributes
    ``mean((1 - d) ** 2)``, computed in float32.

    Args:
        disc_outputs: Iterable of discriminator output tensors for generated
            input.

    Returns:
        A tuple ``(loss, gen_losses)``: the sum of all terms (a scalar tensor,
        or integer ``0`` when ``disc_outputs`` is empty) and the list of
        per-discriminator terms, kept as tensors.
    """
    gen_losses = [torch.mean((1 - out.float()) ** 2) for out in disc_outputs]
    # sum() starts from integer 0, so an empty input yields 0.
    total = sum(gen_losses)
    return total, gen_losses
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """Return the masked mean of a per-element KL-style term.

    The per-element value is
    ``logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p) ** 2 * exp(-2 * logs_p)``.
    It is multiplied by ``z_mask``, summed, and divided by ``sum(z_mask)``.
    All inputs are upcast to float32 first.

    Args:
        z_p, logs_q: [b, h, t_t]
        m_p, logs_p: [b, h, t_t]
        z_mask: Mask multiplied element-wise into the term (presumably
            broadcastable to [b, h, t_t] -- confirm against the caller).

    Returns:
        A scalar float32 tensor.
    """
    z_p, logs_q, m_p, logs_p, z_mask = (
        t.float() for t in (z_p, logs_q, m_p, logs_p, z_mask)
    )
    log_ratio = logs_p - logs_q - 0.5
    quadratic = 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2. * logs_p)
    masked_total = torch.sum((log_ratio + quadratic) * z_mask)
    return masked_total / torch.sum(z_mask)
def mle_loss(z, m, logs, logdet, mask):
    """Return the averaged negative Gaussian log-likelihood with a log-det term.

    Computes
    ``sum(logs) + 0.5 * sum(exp(-2 * logs) * (z - m) ** 2) - sum(logdet)``,
    divides by the number of unmasked elements (``sum(ones_like(z) * mask)``),
    and adds the constant ``0.5 * log(2 * pi)``.

    Inputs are upcast to float32 first, like the other losses in this file:
    with half-precision inputs the full-tensor sums can exceed the fp16 range
    and turn the loss into inf/NaN.

    Args:
        z: Latent tensor.
        m: Mean tensor, same shape as ``z``.
        logs: Log standard-deviation tensor, same shape as ``z``.
        logdet: Tensor of log Jacobian determinants; summed over all elements.
        mask: Mask multiplied against ``ones_like(z)`` to count the elements
            averaged over. NOTE(review): the likelihood sums themselves are
            not masked here, so ``z``/``m``/``logs`` are presumably already
            masked by the caller -- confirm.

    Returns:
        A scalar float32 tensor.
    """
    z = z.float()
    m = m.float()
    logs = logs.float()
    logdet = logdet.float()

    # Negative normal log-likelihood without the constant term.
    nll = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m) ** 2))
    # Subtract the log Jacobian determinant.
    nll = nll - torch.sum(logdet)
    # Average across batch, channel and time axes; ones_like(z) is float32, so
    # the element count is accumulated in float32 whatever dtype mask has.
    nll = nll / torch.sum(torch.ones_like(z) * mask)
    # Add the remaining constant term.
    nll = nll + 0.5 * math.log(2 * math.pi)
    return nll