# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License
"""GAN loss functions (DCGAN/softplus and hinge) for discriminator and generator.

Every function reduces its input with ``torch.mean`` over all elements and
returns 0-dim tensors. The module-level names ``generator_loss`` and
``discriminator_loss`` select the default pair (hinge).
"""
from typing import Tuple

import torch
import torch.nn.functional as F


# DCGAN loss
def loss_dcgan_dis(
    dis_fake: torch.Tensor, dis_real: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the DCGAN discriminator loss as separate real and fake terms.

    Args:
        dis_fake: Discriminator scores for generated samples.
        dis_real: Discriminator scores for real samples.

    Returns:
        ``(loss_real, loss_fake)`` where ``loss_real`` is
        ``mean(softplus(-dis_real))`` and ``loss_fake`` is
        ``mean(softplus(dis_fake))``.
    """
    # softplus(-x) is small when x is large: rewards high scores on real data.
    loss_real = torch.mean(F.softplus(-dis_real))
    # softplus(x) is small when x is very negative: rewards low scores on fakes.
    loss_fake = torch.mean(F.softplus(dis_fake))
    return loss_real, loss_fake


def loss_dcgan_gen(dis_fake: torch.Tensor) -> torch.Tensor:
    """Return the DCGAN generator loss, ``mean(softplus(-dis_fake))``.

    Args:
        dis_fake: Discriminator scores for generated samples.
    """
    return torch.mean(F.softplus(-dis_fake))


# Hinge Loss
def loss_hinge_dis(
    dis_fake: torch.Tensor, dis_real: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the hinge discriminator loss as separate real and fake terms.

    The two terms are returned individually rather than summed; add them
    together to obtain a single scalar discriminator loss.

    Args:
        dis_fake: Discriminator scores for generated samples.
        dis_real: Discriminator scores for real samples.

    Returns:
        ``(loss_real, loss_fake)`` where ``loss_real`` is
        ``mean(relu(1 - dis_real))`` and ``loss_fake`` is
        ``mean(relu(1 + dis_fake))``.
    """
    # Real scores only contribute while they are below the +1 margin.
    loss_real = torch.mean(F.relu(1.0 - dis_real))
    # Fake scores only contribute while they are above the -1 margin.
    loss_fake = torch.mean(F.relu(1.0 + dis_fake))
    return loss_real, loss_fake


def loss_hinge_gen(dis_fake: torch.Tensor) -> torch.Tensor:
    """Return the hinge generator loss, ``-mean(dis_fake)``.

    Args:
        dis_fake: Discriminator scores for generated samples.
    """
    return -torch.mean(dis_fake)


# Default to hinge loss
generator_loss = loss_hinge_gen
discriminator_loss = loss_hinge_dis