# Source: monai / medical / mednist_gan / scripts/losses.py
# Commit 8dff69d ("Initial version") by katielink; original size 672 bytes.
import torch
# Binary cross-entropy criteria for the adversarial game. The discriminator and
# the generator each get their own BCELoss instance (default reduction="mean").
# BCELoss expects its inputs to already be probabilities in [0, 1].
disc_loss_criterion, gen_loss_criterion = torch.nn.BCELoss(), torch.nn.BCELoss()

# Target values used to fill label tensors: 1 marks "real", 0 marks "generated".
real_label, fake_label = 1, 0
def discriminator_loss(gen_images, real_images, disc_net):
    """Compute the discriminator's BCE loss on a real batch and a generated batch.

    Args:
        gen_images: images produced by the generator. They are detached before
            being scored, so this loss does not backpropagate into the generator.
        real_images: batch of real training images.
        disc_net: discriminator, called once on ``real_images`` and once on the
            detached ``gen_images``. Its outputs are passed to ``BCELoss`` and
            therefore must be probabilities in [0, 1].

    Returns:
        Scalar tensor: the average of the real-batch loss (targets filled with
        ``real_label``) and the generated-batch loss (targets filled with
        ``fake_label``).
    """
    # The real batch is scored before the generated batch; keep this order if
    # disc_net has stateful layers (e.g. batch-norm running statistics).
    real_output = disc_net(real_images)
    # Detach so the discriminator update leaves the generator's graph untouched.
    gen_output = disc_net(gen_images.detach())

    # Build targets from each discriminator output, not from the images. BCELoss
    # requires target and input to share shape and dtype. The discriminator need
    # not return (batch, 1), and its dtype/device may differ from the images'.
    # This also matches how generator_loss builds its targets.
    real_target = real_output.new_full(real_output.shape, real_label)
    gen_target = gen_output.new_full(gen_output.shape, fake_label)

    realloss = disc_loss_criterion(real_output, real_target)
    genloss = disc_loss_criterion(gen_output, gen_target)
    return (genloss + realloss) / 2
def generator_loss(gen_images, disc_net):
    """Compute the generator's adversarial BCE loss.

    The generated images are scored by ``disc_net`` and compared against
    targets filled with ``real_label``. The generator is therefore rewarded when
    the discriminator classifies its samples as real.

    Args:
        gen_images: images produced by the generator (not detached, so gradients
            flow back into the generator).
        disc_net: discriminator whose outputs are passed to ``BCELoss`` and so
            must be probabilities in [0, 1].

    Returns:
        The ``gen_loss_criterion`` value for the discriminator's predictions.
    """
    predictions = disc_net(gen_images)
    # Same shape, dtype and device as the predictions, every entry real_label.
    targets = torch.full_like(predictions, real_label)
    return gen_loss_criterion(predictions, targets)