File size: 2,339 Bytes
4a5aa3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class GANLoss(nn.Module):
    """
    Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, loss_mode="vanilla", real_label=1.0, fake_label=0.0):
        """
        ---------
        Arguments
        ---------
        loss_mode : str
            GAN loss mode (default="vanilla"). Only "vanilla", which uses
            ``nn.BCEWithLogitsLoss`` on raw discriminator logits, is supported.
        real_label : float
            target value used for real images (default=1.0)
        fake_label : float
            target value used for fake images (default=0.0)

        ------
        Raises
        ------
        NotImplementedError
            if ``loss_mode`` is not a supported mode
        """
        super().__init__()
        self.loss_mode = loss_mode
        # Labels are stored as floating-point buffers. BCEWithLogitsLoss needs a
        # floating target, and an int label (e.g. real_label=1) would otherwise
        # become an int64 tensor. As buffers they follow the module on
        # .to()/.cuda() and are saved in the state_dict.
        self.register_buffer("real_label", torch.tensor(float(real_label)))
        self.register_buffer("fake_label", torch.tensor(float(fake_label)))
        if self.loss_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        else:
            raise NotImplementedError(
                f"GANLoss with {self.loss_mode} mode - not implemented yet"
            )

    def get_target_tensor(self, prediction, target_is_real):
        """
        ---------
        Arguments
        ---------
        prediction : tensor
            prediction from a discriminator
        target_is_real : bool
            whether the groundtruth label is for a real image or a fake image

        -------
        Returns
        -------
        tensor : A label tensor filled with groundtruth label with the same size as that of input
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        # Match the prediction's dtype and device, e.g. a float64/half output or
        # a criterion whose buffers were never moved. Then broadcast the scalar
        # label to the prediction's shape as a view, without copying.
        return target_tensor.to(prediction).expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """
        Compute the GAN loss for a discriminator output.

        Implemented as ``forward`` rather than by overriding ``__call__``, so
        that ``nn.Module.__call__`` still runs registered hooks. Callers invoke
        the criterion as before: ``criterion(prediction, target_is_real)``.

        ---------
        Arguments
        ---------
        prediction : tensor
            prediction from a discriminator
        target_is_real : bool
            whether the groundtruth label is for a real image or a fake image

        -------
        Returns
        -------
        loss : the computed loss

        ------
        Raises
        ------
        NotImplementedError
            if ``self.loss_mode`` is not a supported mode. Raising here avoids
            silently returning a zero loss.
        """
        if self.loss_mode == "vanilla":
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            return self.loss(prediction, target_tensor)
        raise NotImplementedError(
            f"GANLoss with {self.loss_mode} mode - not implemented yet"
        )
|