APISR / loss /gan_loss.py
HikariDawn's picture
feat: initial push
561c629
raw
history blame
3.68 kB
# -*- coding: utf-8 -*-
import math
import torch
from torch import autograd as autograd
from torch import nn as nn
from torch.nn import functional as F
import cv2
import numpy as np
import os, sys
root_path = os.path.abspath('.')
sys.path.append(root_path)
from loss.perceptual_loss import VGGFeatureExtractor
from degradation.ESR.utils import np2tensor, tensor2np, save_img
class GANLoss(nn.Module):
    """Adversarial loss of a prediction against a constant real/fake target.

    From Real-ESRGAN code. Only the 'vanilla' and 'lsgan' variants are
    implemented in this class (the wgan part is skipped here), so any other
    ``gan_type`` is rejected at construction time.

    Args:
        gan_type (str): 'vanilla' (uses ``nn.BCEWithLogitsLoss``) or
            'lsgan' (uses ``nn.MSELoss``). Default: 'vanilla'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only applied for generators; the
            discriminator loss is returned unweighted.

    Raises:
        NotImplementedError: If ``gan_type`` is neither 'vanilla' nor 'lsgan'.
    """

    def __init__(self, gan_type="vanilla", real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # gan type is vanilla usually
        if gan_type == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_type == "lsgan":
            self.loss = nn.MSELoss()
        else:
            # Name the offending value so a config typo is easy to spot.
            raise NotImplementedError(f"We didn't implement this GAN type: {gan_type!r}")

    def get_target_label(self, input, target_is_real):
        """Build the target tensor that ``input`` is compared against.

        Args:
            input (Tensor): Input tensor (the network prediction).
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            Tensor: A tensor shaped like ``input`` and filled with
                ``real_label_val`` or ``fake_label_val``.
        """
        target_val = self.real_label_val if target_is_real else self.fake_label_val
        # full_like takes shape, dtype and device from the prediction in a
        # single allocation.
        return torch.full_like(input, target_val)

    def forward(self, input, target_is_real, is_disc=False):
        """Compute the GAN loss for one prediction tensor.

        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
class MultiScaleGANLoss(GANLoss):
    """GAN loss that also accepts a collection of predictions.

    When given a list or tuple, the per-prediction GAN losses are averaged.
    A single tensor is handled exactly as in ``GANLoss``.

    Args:
        gan_type (str): Forwarded to ``GANLoss``. Default: 'vanilla'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight, applied to generator losses only.
            Default: 1.0.
    """

    def __init__(self, gan_type="vanilla", real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)

    def forward(self, input, target_is_real, is_disc=False):
        """Compute the GAN loss for one prediction or the mean over several.

        Args:
            input (Tensor | list | tuple): A single prediction tensor, a
                sequence of tensors, or a sequence whose items are themselves
                sequences of tensors.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value; for a sequence input, the average of the
                per-item losses.

        Raises:
            ZeroDivisionError: If ``input`` is an empty list or tuple.
        """
        # Single tensor: defer entirely to the base implementation.
        if not isinstance(input, (list, tuple)):
            return super().forward(input, target_is_real, is_disc)

        loss = 0
        for pred_i in input:
            if isinstance(pred_i, (list, tuple)):
                # Only compute GAN loss for the last layer
                # in case of multiscale feature matching
                pred_i = pred_i[-1]
            # The base loss is already a 0-dim tensor, so .mean() leaves it
            # unchanged; it is kept as a harmless safeguard.
            loss += super().forward(pred_i, target_is_real, is_disc).mean()
        return loss / len(input)