import torch
import torch.nn as nn
import torch.nn.functional as F

from models.vgg import Vgg19
from utils.image_processing import gram


def to_gray_scale(image):
    """Convert a batch of RGB images in [-1, 1] to 3-channel grayscale in [-1, 1].

    Luma weights follow ITU-R BT.601, matching torchvision's rgb_to_grayscale:
    https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional/_color.py#L33

    Args:
        image: Tensor of shape (..., 3, H, W), values assumed in [-1, 1].

    Returns:
        Tensor of the same shape and dtype with all 3 channels equal to the luma.
    """
    image = (image + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    r, g, b = image.unbind(dim=-3)
    # BT.601 luma: 0.2989*R + 0.587*G + 0.114*B.
    # In-place adds are safe here: they mutate the fresh tensor from `mul`.
    l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
    l_img = l_img.unsqueeze(dim=-3).to(image.dtype)
    l_img = l_img.expand(image.shape)  # replicate luma across the 3 channels
    return l_img / 0.5 - 1.0  # [0, 1] -> [-1, 1]


class ColorLoss(nn.Module):
    """Color reconstruction loss computed in YUV space.

    L1 on the luma (Y) channel plus Smooth-L1 (Huber) on each chroma (U, V)
    channel, comparing a real image against a generated one.
    """

    def __init__(self):
        super(ColorLoss, self).__init__()
        self.l1 = nn.L1Loss()
        self.huber = nn.SmoothL1Loss()
        # Registered as a (non-persistent) buffer so Module.to()/.cuda() moves
        # it together with the module — no hand-rolled `.to()` override needed,
        # and state_dict stays unchanged.
        self.register_buffer(
            '_rgb_to_yuv_kernel',
            torch.tensor([
                [0.299, 0.587, 0.114],
                [-0.14714119, -0.28886916, 0.43601035],
                [0.61497538, -0.51496512, -0.10001026],
            ]).float(),
            persistent=False,
        )

    def rgb_to_yuv(self, image):
        '''
        https://en.wikipedia.org/wiki/YUV

        Convert an RGB batch in [-1, 1] to YUV.

        @Args:
            - image: Tensor of shape (N, 3, H, W) in [-1, 1]

        @Returns:
            - Tensor of shape (N, H, W, 3) (channel last) in YUV
        '''
        # -1 1 -> 0 1
        image = (image + 1.0) / 2.0
        image = image.permute(0, 2, 3, 1)  # To channel last
        yuv_img = image @ self._rgb_to_yuv_kernel.T
        return yuv_img

    def forward(self, image, image_g):
        image = self.rgb_to_yuv(image)
        image_g = self.rgb_to_yuv(image_g)

        # After convert to yuv, both images have channel last:
        # [..., 0] = Y, [..., 1] = U, [..., 2] = V.
        return (
            self.l1(image[:, :, :, 0], image_g[:, :, :, 0])
            + self.huber(image[:, :, :, 1], image_g[:, :, :, 1])
            + self.huber(image[:, :, :, 2], image_g[:, :, :, 2])
        )


class AnimeGanLoss:
    """Bundle of Generator and Discriminator losses for AnimeGAN training.

    @Args:
        - args: namespace carrying the loss weights (wadvg, wadvd, wcon, wgra,
          wcol, wtvar) and the adversarial loss type (gan_loss: one of
          'hinge' | 'lsgan' | 'bce')
        - device: torch.device (or device string) the losses and VGG run on
        - gray_adv: if True, the Discriminator treats grayscale anime images
          as the real class
    """

    def __init__(self, args, device, gray_adv=False):
        if isinstance(device, str):
            device = torch.device(device)
        self.content_loss = nn.L1Loss().to(device)
        self.gram_loss = nn.L1Loss().to(device)
        self.color_loss = ColorLoss().to(device)
        self.wadvg = args.wadvg
        self.wadvd = args.wadvd
        self.wcon = args.wcon
        self.wgra = args.wgra
        self.wcol = args.wcol
        self.wtvar = args.wtvar
        # If true, use gray scale image to calculate adversarial loss
        self.gray_adv = gray_adv
        # Frozen feature extractor for content/gram losses.
        self.vgg19 = Vgg19().to(device).eval()
        self.adv_type = args.gan_loss
        self.bce_loss = nn.BCEWithLogitsLoss()

    def compute_loss_G(self, fake_img, img, fake_logit, anime_gray):
        '''
        Compute loss for Generator

        @Args:
            - fake_img: generated image
            - img: real image
            - fake_logit: output of Discriminator given fake image
            - anime_gray: grayscale of anime image

        @Returns:
            - Adversarial Loss of fake logits
            - Content loss between real and fake features (vgg19)
            - Gram loss between anime and fake features (Vgg19)
            - Color loss between image and fake image
            - Total variation loss of fake image
        '''
        fake_feat = self.vgg19(fake_img)
        gray_feat = self.vgg19(anime_gray)
        img_feat = self.vgg19(img)

        return [
            # Want to be real image.
            self.wadvg * self.adv_loss_g(fake_logit),
            self.wcon * self.content_loss(img_feat, fake_feat),
            self.wgra * self.gram_loss(gram(gray_feat), gram(fake_feat)),
            self.wcol * self.color_loss(img, fake_img),
            self.wtvar * self.total_variation_loss(fake_img),
        ]

    def compute_loss_D(
        self,
        fake_img_d,
        real_anime_d,
        real_anime_gray_d,
        real_anime_smooth_gray_d=None,
    ):
        """Compute the Discriminator adversarial loss from its logits.

        When `gray_adv` is set, grayscale anime is the real class and the
        smoothed grayscale counts (down-weighted) as fake; otherwise the
        standard real-anime-vs-generated objective is used.
        """
        if self.gray_adv:
            # Treat gray scale image as real
            return (
                self.adv_loss_d_real(real_anime_gray_d)
                + self.adv_loss_d_fake(fake_img_d)
                + 0.3 * self.adv_loss_d_fake(real_anime_smooth_gray_d)
            )
        else:
            return (
                # Classify real anime as real
                self.adv_loss_d_real(real_anime_d)
                # Classify generated as fake
                + self.adv_loss_d_fake(fake_img_d)
            )

    def total_variation_loss(self, fake_img):
        """
        A smooth loss in fact. Like the smooth prior in MRF.
        V(y) = || y_{n+1} - y_n ||_2
        """
        # Channel first -> channel last
        fake_img = fake_img.permute(0, 2, 3, 1)

        def _l2(x):
            # sum(t ** 2) / 2
            return torch.sum(x ** 2) / 2

        # Differences between vertically / horizontally adjacent pixels,
        # each normalized by its own element count.
        dh = fake_img[:, :-1, ...] - fake_img[:, 1:, ...]
        dw = fake_img[:, :, :-1, ...] - fake_img[:, :, 1:, ...]
        return _l2(dh) / dh.numel() + _l2(dw) / dw.numel()

    def content_loss_vgg(self, image, recontruction):
        """VGG feature-space L1 loss between an image and its reconstruction.

        NOTE(review): a pixel-space term (0.5 * L1(image, reconstruction)) was
        previously computed but commented out of the return; the dead
        computation has been removed.
        """
        feat = self.vgg19(image)
        re_feat = self.vgg19(recontruction)
        return self.content_loss(feat, re_feat)

    def adv_loss_d_real(self, pred):
        """Push pred to class 1 (real)"""
        if self.adv_type == 'hinge':
            return torch.mean(F.relu(1.0 - pred))
        elif self.adv_type == 'lsgan':
            return torch.mean(torch.square(pred - 1.0))
        elif self.adv_type == 'bce':
            return self.bce_loss(pred, torch.ones_like(pred))
        raise ValueError(f'Do not support loss type {self.adv_type}')

    def adv_loss_d_fake(self, pred):
        """Push pred to class 0 (fake)"""
        if self.adv_type == 'hinge':
            return torch.mean(F.relu(1.0 + pred))
        elif self.adv_type == 'lsgan':
            return torch.mean(torch.square(pred))
        elif self.adv_type == 'bce':
            return self.bce_loss(pred, torch.zeros_like(pred))
        raise ValueError(f'Do not support loss type {self.adv_type}')

    def adv_loss_g(self, pred):
        """Push pred to class 1 (real)"""
        if self.adv_type == 'hinge':
            return -torch.mean(pred)
        elif self.adv_type == 'lsgan':
            return torch.mean(torch.square(pred - 1.0))
        elif self.adv_type == 'bce':
            return self.bce_loss(pred, torch.ones_like(pred))
        raise ValueError(f'Do not support loss type {self.adv_type}')


class LossSummary:
    """Running accumulator for per-step G and D losses with averaging helpers."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated losses."""
        self.loss_g_adv = []
        self.loss_content = []
        self.loss_gram = []
        self.loss_color = []
        self.loss_d_adv = []

    def update_loss_G(self, adv, gram, color, content):
        """Record one step of Generator loss components (detached to numpy)."""
        self.loss_g_adv.append(adv.cpu().detach().numpy())
        self.loss_gram.append(gram.cpu().detach().numpy())
        self.loss_color.append(color.cpu().detach().numpy())
        self.loss_content.append(content.cpu().detach().numpy())

    def update_loss_D(self, loss):
        """Record one step of Discriminator loss (detached to numpy)."""
        self.loss_d_adv.append(loss.cpu().detach().numpy())

    def avg_loss_G(self):
        """Return (avg_adv, avg_gram, avg_color, avg_content) over recorded steps."""
        return (
            self._avg(self.loss_g_adv),
            self._avg(self.loss_gram),
            self._avg(self.loss_color),
            self._avg(self.loss_content),
        )

    def avg_loss_D(self):
        """Return the average Discriminator loss over recorded steps."""
        return self._avg(self.loss_d_adv)

    def get_loss_description(self):
        """One-line human-readable summary of the current average losses."""
        avg_adv, avg_gram, avg_color, avg_content = self.avg_loss_G()
        avg_adv_d = self.avg_loss_D()
        # Fixed: ':2f' was a width-2 spec with default 6-decimal precision;
        # ':.2f' (2 decimal places) was clearly intended.
        return (
            f'loss G: adv {avg_adv:.2f} con {avg_content:.2f} '
            f'gram {avg_gram:.2f} color {avg_color:.2f} / loss D: {avg_adv_d:.2f}'
        )

    @staticmethod
    def _avg(losses):
        # Guard against division by zero when called before any update.
        if not losses:
            return 0.0
        return sum(losses) / len(losses)