import torch

# Target values written into label tensors: 1 marks real samples,
# 0 marks generated (fake) samples.
real_label, fake_label = 1, 0

# Separate binary cross-entropy criteria for the two players.
# nn.BCELoss expects probabilities in [0, 1], so the discriminator passed to
# the loss functions presumably ends in a sigmoid — confirm against the model.
disc_loss_criterion = torch.nn.BCELoss()
gen_loss_criterion = torch.nn.BCELoss()


def discriminator_loss(gen_images, real_images, disc_net):
    """Return the discriminator's binary cross-entropy loss for one batch.

    ``disc_net`` scores ``real_images`` against targets filled with
    ``real_label`` and ``gen_images`` against targets filled with
    ``fake_label``; the two BCE terms are averaged.

    Args:
        gen_images: Batch produced by the generator. It is detached before
            scoring, so this loss produces no gradient for whatever created
            ``gen_images``.
        real_images: Batch of real samples.
        disc_net: Callable discriminator. Its outputs go straight into
            ``nn.BCELoss``, so they must be probabilities in [0, 1]
            (presumably a sigmoid head — confirm against the model).

    Returns:
        Scalar tensor ``(fake_loss + real_loss) / 2``.
    """
    # Score the real batch first, then the detached fake batch.
    real_output = disc_net(real_images)
    fake_output = disc_net(gen_images.detach())

    # Build each target from the discriminator output itself, so it always
    # matches the output's shape, dtype and device. Assuming a (batch, 1)
    # shape would make BCELoss reject discriminators that emit (batch,) or
    # (batch, 1, 1, 1) scores.
    real_targets = real_output.new_full(real_output.shape, real_label)
    fake_targets = fake_output.new_full(fake_output.shape, fake_label)

    real_loss = disc_loss_criterion(real_output, real_targets)
    fake_loss = disc_loss_criterion(fake_output, fake_targets)
    return (fake_loss + real_loss) / 2


def generator_loss(gen_images, disc_net):
    """Return the generator's binary cross-entropy loss for one batch.

    The generated batch is scored by ``disc_net``. The scores are compared
    against targets filled with ``real_label``, so the loss is low when the
    discriminator rates the generated samples as real. With ``real_label == 1``
    this equals ``-mean(log(disc_net(gen_images)))``.

    ``gen_images`` is not detached, so gradients flow back through
    ``disc_net`` into whatever produced ``gen_images``.

    Args:
        gen_images: Batch produced by the generator.
        disc_net: Callable discriminator whose outputs are probabilities in
            [0, 1] (required by ``nn.BCELoss``; presumably a sigmoid head —
            confirm against the model).

    Returns:
        Scalar BCE loss tensor.
    """
    scores = disc_net(gen_images)
    # Targets share the scores' shape, dtype and device.
    wanted = torch.full_like(scores, real_label)
    return gen_loss_criterion(scores, wanted)