|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class GANLoss(nn.Module): |
|
""" |
|
Define different GAN objectives. |
|
The GANLoss class abstracts away the need to create the target label tensor |
|
that has the same size as the input. |
|
""" |
|
|
|
def __init__(self, loss_mode="vanilla", real_label=1.0, fake_label=0.0): |
|
""" |
|
--------- |
|
Arguments |
|
--------- |
|
loss_mode : str |
|
GAN loss mode (default="vanilla") |
|
real_label : bool |
|
label for real image |
|
fake_label : bool |
|
label for fake image |
|
""" |
|
super().__init__() |
|
self.loss_mode = loss_mode |
|
self.register_buffer("real_label", torch.tensor(real_label)) |
|
self.register_buffer("fake_label", torch.tensor(fake_label)) |
|
|
|
self.loss = None |
|
if self.loss_mode == "vanilla": |
|
self.loss = nn.BCEWithLogitsLoss() |
|
else: |
|
raise NotImplementedError( |
|
f"GANLoss with {self.loss_mode} mode - not implemented yet" |
|
) |
|
|
|
def get_target_tensor(self, prediction, target_is_real): |
|
""" |
|
--------- |
|
Arguments |
|
--------- |
|
prediction : tensor |
|
prediction from a discriminator |
|
target_is_real : bool |
|
whether the groundtruth label is for a real image or a fake image |
|
|
|
------- |
|
Returns |
|
------- |
|
tensor : A label tensor filled with groundtruth label with the same size as that of input |
|
""" |
|
if target_is_real: |
|
target_tensor = self.real_label |
|
else: |
|
target_tensor = self.fake_label |
|
return target_tensor.expand_as(prediction) |
|
|
|
def __call__(self, prediction, target_is_real): |
|
""" |
|
--------- |
|
Arguments |
|
--------- |
|
prediction : tensor |
|
prediction from a discriminator |
|
target_is_real : bool |
|
whether the groundtruth label is for a real image or a fake image |
|
|
|
------- |
|
Returns |
|
------- |
|
loss : the computed loss |
|
""" |
|
if self.loss_mode == "vanilla": |
|
target_tensor = self.get_target_tensor(prediction, target_is_real) |
|
loss = self.loss(prediction, target_tensor) |
|
else: |
|
loss = 0 |
|
return loss |
|
|