# -*- coding: utf-8 -*-
import math
import torch
from torch import autograd as autograd
from torch import nn as nn
from torch.nn import functional as F
import cv2
import numpy as np
import os
import sys

# Make project-local packages importable when running from the repo root.
# This must happen before the project imports below.
root_path = os.path.abspath('.')
sys.path.append(root_path)

from loss.perceptual_loss import VGGFeatureExtractor
from degradation.ESR.utils import np2tensor, tensor2np, save_img


class GANLoss(nn.Module):
    """Define GAN loss.

    From Real-ESRGAN code. The 'wgan' and 'hinge' variants of the original
    implementation are intentionally not ported here.

    Args:
        gan_type (str): Supported values are 'vanilla' (BCE with logits) and
            'lsgan' (MSE). Any other value raises ``NotImplementedError``.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always
            1.0 for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not 'vanilla' or 'lsgan'.
    """

    def __init__(self, gan_type="vanilla", real_label_val=1.0,
                 fake_label_val=0.0, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # gan type is vanilla usually
        if gan_type == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_type == "lsgan":
            self.loss = nn.MSELoss()
        else:
            # The wgan / hinge branches are skipped in this port.
            raise NotImplementedError("We didn't implement this GAN type")

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            Tensor: Tensor with the same size as ``input`` whose elements
            all equal ``real_label_val`` or ``fake_label_val``. It is built
            with ``input.new_ones`` so it is allocated alongside ``input``.
        """
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val
        )
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """Compute the GAN loss for a single prediction tensor.

        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value, scaled by ``loss_weight`` unless
            ``is_disc`` is True.
        """
        target_label = self.get_target_label(input, target_is_real)
        loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight


class MultiScaleGANLoss(GANLoss):
    """GAN loss that also accepts a sequence of predictions.

    Each element of the sequence may itself be a sequence of tensors (e.g.
    intermediate features kept for feature matching); in that case only the
    last tensor of the inner sequence is used for the GAN loss.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__(gan_type, real_label_val, fake_label_val,
                         loss_weight)

    def forward(self, input, target_is_real, is_disc=False):
        """Compute the GAN loss averaged over all given predictions.

        Args:
            input (Tensor | list | tuple): A single prediction tensor, a
                sequence of tensors, or a sequence of sequences of tensors.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is for discriminators or not.
                Default: False.

        Returns:
            Tensor: For a sequence, the mean of the per-prediction losses;
            for a single tensor, the same value as ``GANLoss.forward``.

        Raises:
            ValueError: If ``input`` is an empty sequence (there is nothing
                to average over).
        """
        if not isinstance(input, (list, tuple)):
            return super().forward(input, target_is_real, is_disc)

        if not input:
            raise ValueError(
                "MultiScaleGANLoss.forward received an empty sequence of "
                "predictions"
            )

        total = 0
        for pred_i in input:
            if isinstance(pred_i, (list, tuple)):
                # Only compute GAN loss for the last layer
                # in case of multiscale feature matching
                pred_i = pred_i[-1]
            # Safe operation: the parent loss is already reduced to a 0-dim
            # tensor, so .mean() leaves its value unchanged.
            total = total + super().forward(
                pred_i, target_is_real, is_disc
            ).mean()
        return total / len(input)