# ECON: lib/net/GANLoss.py
# Last commit 487ee6d by Yuliang: "Support TEXTure" (file size 2.18 kB)
""" The code is based on https://github.com/apple/ml-gsn/ with adaptations. """
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
from lib.net.Discriminator import StyleDiscriminator
def hinge_loss(fake_pred, real_pred, mode):
    """Hinge adversarial loss.

    Args:
        fake_pred: discriminator outputs for generated samples.
        real_pred: discriminator outputs for real samples (only read when
            ``mode == 'd'``).
        mode: ``'d'`` for the discriminator loss, ``'g'`` for the generator loss.

    Returns:
        Scalar tensor. For ``'d'``: ``mean(relu(1 + fake)) + mean(relu(1 - real))``.
        For ``'g'``: ``-mean(fake)``.

    Raises:
        ValueError: if ``mode`` is not ``'d'`` or ``'g'`` (previously this fell
            through and failed with an UnboundLocalError).
    """
    if mode == 'd':
        # Discriminator update: penalize real logits below +1 and fake logits above -1.
        d_loss_fake = torch.mean(F.relu(1.0 + fake_pred))
        d_loss_real = torch.mean(F.relu(1.0 - real_pred))
        return d_loss_fake + d_loss_real
    if mode == 'g':
        # Generator update
        return -torch.mean(fake_pred)
    raise ValueError(f"mode must be 'd' or 'g', got {mode!r}")
def logistic_loss(fake_pred, real_pred, mode):
    """Logistic (softplus / non-saturating) adversarial loss.

    Args:
        fake_pred: discriminator outputs for generated samples.
        real_pred: discriminator outputs for real samples (only read when
            ``mode == 'd'``).
        mode: ``'d'`` for the discriminator loss, ``'g'`` for the generator loss.

    Returns:
        Scalar tensor. For ``'d'``: ``mean(softplus(fake)) + mean(softplus(-real))``.
        For ``'g'``: ``mean(softplus(-fake))``.

    Raises:
        ValueError: if ``mode`` is not ``'d'`` or ``'g'`` (previously this fell
            through and failed with an UnboundLocalError).
    """
    if mode == 'd':
        # Discriminator update
        d_loss_fake = torch.mean(F.softplus(fake_pred))
        d_loss_real = torch.mean(F.softplus(-real_pred))
        return d_loss_fake + d_loss_real
    if mode == 'g':
        # Generator update
        return torch.mean(F.softplus(-fake_pred))
    raise ValueError(f"mode must be 'd' or 'g', got {mode!r}")
def r1_loss(real_pred, real_img):
    """R1 gradient penalty computed on real inputs.

    Differentiates ``real_pred.sum()`` with respect to ``real_img``. The graph is
    kept (``create_graph=True``) so the penalty itself can be backpropagated.
    The squared gradient is then summed over every dimension except the
    leading one (presumably the batch dimension) and averaged over that
    leading dimension.

    Returns:
        Scalar tensor: mean over dim 0 of the per-sample squared gradient norm.
    """
    total_score = real_pred.sum()
    grads = autograd.grad(outputs=total_score, inputs=real_img, create_graph=True)[0]
    leading = grads.shape[0]
    squared_norms = grads.pow(2).reshape(leading, -1).sum(1)
    return squared_norms.mean()
class GANLoss(nn.Module):
    """Weighted discriminator-side adversarial loss built on a StyleDiscriminator.

    Args:
        opt: configuration object. Only ``opt.gan`` is read; it must provide
            ``img_res`` (passed to ``StyleDiscriminator``) and ``lambda_gan``
            (multiplier applied to the returned loss).
        disc_loss: name of the loss function, ``'logistic'`` (default) or ``'hinge'``.

    Raises:
        ValueError: if ``disc_loss`` is not a supported name. Previously an
            unknown name left ``self.disc_loss`` unset and only failed later,
            inside ``forward``.
    """

    def __init__(
        self,
        opt,
        disc_loss='logistic',
    ):
        super().__init__()
        # Validate before building the discriminator so a typo fails fast with
        # a clear message.
        if disc_loss not in ('hinge', 'logistic'):
            raise ValueError(
                f"disc_loss must be 'hinge' or 'logistic', got {disc_loss!r}"
            )
        self.opt = opt.gan
        input_dim = 3  # number of input channels given to the discriminator (hard-coded)
        self.discriminator = StyleDiscriminator(input_dim, self.opt.img_res)
        if disc_loss == 'hinge':
            self.disc_loss = hinge_loss
        else:
            self.disc_loss = logistic_loss

    def forward(self, input):
        """Score a real and a fake batch and return the weighted discriminator loss.

        Args:
            input: mapping holding the real batch under ``'norm_real'`` and the
                fake batch under ``'norm_fake'``. Both are passed through
                ``self.discriminator``, real first.

        Returns:
            ``(loss, log)``. ``loss`` is the discriminator loss (``mode='d'``)
            multiplied by ``opt.gan.lambda_gan``. ``log`` holds detached
            values: the unscaled ``'disc_loss'`` and the mean real and fake
            logits.
        """
        logits_real = self.discriminator(input['norm_real'])
        logits_fake = self.discriminator(input['norm_fake'])

        disc_loss = self.disc_loss(fake_pred=logits_fake, real_pred=logits_real, mode='d')

        log = {
            "disc_loss": disc_loss.detach(),
            "logits_real": logits_real.mean().detach(),
            "logits_fake": logits_fake.mean().detach(),
        }

        return disc_loss * self.opt.lambda_gan, log