abhishekrs4's picture
added image_colourization_cgan module
4a5aa3d
raw
history blame contribute delete
No virus
2.34 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
class GANLoss(nn.Module):
    """
    Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.

    Only the "vanilla" objective is implemented: ``nn.BCEWithLogitsLoss``
    applied to the discriminator output, so predictions are expected to be raw
    logits (no sigmoid applied). Any other ``loss_mode`` raises
    ``NotImplementedError`` at construction time.

    Use it like any ``nn.Module``: ``loss = criterion(prediction, True)``.
    """

    def __init__(self, loss_mode="vanilla", real_label=1.0, fake_label=0.0):
        """
        ---------
        Arguments
        ---------
        loss_mode : str
            GAN loss mode (default="vanilla"; the only supported mode)
        real_label : float
            target value used for real images (default=1.0)
        fake_label : float
            target value used for fake images (default=0.0)

        ------
        Raises
        ------
        NotImplementedError : if loss_mode is not "vanilla"
        """
        super().__init__()
        self.loss_mode = loss_mode
        # Labels are stored as floating-point buffers so they move with the
        # module on .to()/.cuda() calls and are never integer tensors
        # (BCEWithLogitsLoss requires a floating-point target), even when
        # ints such as 1 / 0 are passed in.
        self.register_buffer("real_label", torch.tensor(float(real_label)))
        self.register_buffer("fake_label", torch.tensor(float(fake_label)))
        self.loss = None
        if self.loss_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        else:
            raise NotImplementedError(
                f"GANLoss with {self.loss_mode} mode - not implemented yet"
            )

    def get_target_tensor(self, prediction, target_is_real):
        """
        ---------
        Arguments
        ---------
        prediction : tensor
            prediction from a discriminator
        target_is_real : bool
            whether the groundtruth label is for a real image or a fake image

        -------
        Returns
        -------
        tensor : a label tensor filled with the groundtruth label, with the
            same shape, dtype and device as ``prediction``
        """
        label = self.real_label if target_is_real else self.fake_label
        # Match the prediction's dtype/device so half/double logits work and a
        # module that was never moved to the prediction's device still works.
        # expand_as returns a broadcast view, so no full-size copy is made.
        return label.to(device=prediction.device, dtype=prediction.dtype).expand_as(
            prediction
        )

    def forward(self, prediction, target_is_real):
        """
        Compute the GAN loss for a batch of discriminator predictions.

        Implemented as ``forward`` (not ``__call__``) so that
        ``nn.Module.__call__`` still runs registered hooks.

        ---------
        Arguments
        ---------
        prediction : tensor
            prediction (logits) from a discriminator
        target_is_real : bool
            whether the groundtruth label is for a real image or a fake image

        -------
        Returns
        -------
        loss : the computed loss (scalar tensor, mean reduction)

        ------
        Raises
        ------
        NotImplementedError : if loss_mode was changed to an unsupported mode
            after construction (previously this silently returned 0)
        """
        if self.loss_mode != "vanilla":
            raise NotImplementedError(
                f"GANLoss with {self.loss_mode} mode - not implemented yet"
            )
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)