# BayesCap / losses.py
# Source: uploaded by udion, commit 99e984c ("hfspace gradio demo").
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor
class ContentLoss(nn.Module):
    """Perceptual (content) loss: L1 distance between VGG19 feature maps.

    Both images are normalised with the ImageNet channel mean/std and pushed
    through the first 36 modules of an ImageNet-pretrained VGG19 ``features``
    trunk (in torchvision's VGG19 layout this stops at the ReLU after the last
    convolution, just before the final max-pool). Features from these late
    layers emphasise texture content rather than exact pixel values.

    Paper reference list:
        - `Photo-Realistic Single Image Super-Resolution Using a Generative
          Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>`_
        - `ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
          <https://arxiv.org/pdf/1809.00219.pdf>`_
        - `Perceptual Extreme Super Resolution Network with Receptive Field
          Block <https://arxiv.org/pdf/2005.12597.pdf>`_

    NOTE(review): construction calls ``models.vgg19(pretrained=True)``, which
    may need to fetch the pretrained weights on first use.
    """

    def __init__(self) -> None:
        super().__init__()
        backbone = models.vgg19(pretrained=True).eval()
        # Keep modules [0, 36) of the convolutional trunk as the feature extractor.
        trunk = list(backbone.features.children())[:36]
        self.feature_extractor = nn.Sequential(*trunk)
        # The VGG weights act as a fixed feature extractor and are never trained.
        self.feature_extractor.requires_grad_(False)
        # ImageNet normalisation statistics, shaped (1, 3, 1, 1) so they broadcast
        # over the channel dimension of 4-D (N, 3, H, W)-style inputs.
        imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer("mean", imagenet_mean)
        self.register_buffer("std", imagenet_std)

    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        """Return the mean absolute difference between the VGG features of ``sr`` and ``hr``.

        NOTE(review): inputs are presumably RGB images scaled to [0, 1], since
        ImageNet statistics are applied - confirm against callers.
        """
        sr_features, hr_features = (
            self.feature_extractor((image - self.mean) / self.std)
            for image in (sr, hr)
        )
        return F.l1_loss(sr_features, hr_features)
class GenGaussLoss(nn.Module):
    """Generalized-Gaussian negative-log-likelihood-style loss.

    Given a predicted location ``mean``, an inverse scale ``one_over_alpha`` and
    a shape ``beta`` (combined elementwise with ``target``, so they must
    broadcast against it), the per-element loss is::

        resi - log(1/alpha') + lgamma(1/beta') - log(beta')

    where ``1/alpha' = one_over_alpha + alpha_eps``, ``beta' = beta + beta_eps``
    and ``resi = clamp(|mean - target| * (1/alpha') * beta', resi_min, resi_max)``.

    NOTE(review): the exact generalized-Gaussian NLL uses
    ``(|mean - target| * (1/alpha')) ** beta'`` for the residual term (this
    form appears commented out in the original code); this implementation uses
    the product form above. It is left unchanged here - confirm which form is
    intended.

    Args:
        reduction: ``'mean'`` or ``'sum'`` to reduce over all elements, or
            ``'none'`` to return the per-element loss tensor (the same options
            as ``nn.L1Loss``).
        alpha_eps: added to ``one_over_alpha`` before taking its log.
        beta_eps: added to ``beta`` before taking its log and reciprocal.
        resi_min, resi_max: clamp bounds applied to the residual term.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3,
    ) -> None:
        super(GenGaussLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
    ):
        """Compute the loss.

        Returns:
            The reduced loss (``'mean'``/``'sum'``), the per-element loss
            (``'none'``), or ``None`` if the clamped residual contains NaNs
            (a message is printed in that case).

        Raises:
            ValueError: if ``self.reduction`` is not ``'mean'``, ``'sum'`` or
                ``'none'``.
        """
        if self.reduction not in self._REDUCTIONS:
            raise ValueError(
                f"GenGaussLoss: unsupported reduction {self.reduction!r}; "
                "expected 'mean', 'sum' or 'none'"
            )

        one_over_alpha1 = one_over_alpha + self.alpha_eps
        beta1 = beta + self.beta_eps

        # Product-form residual (see class docstring). Exact GGD alternative:
        #   torch.pow(torch.abs(mean - target) * one_over_alpha1, beta1)
        resi = (torch.abs(mean - target) * one_over_alpha1 * beta1).clamp(
            min=self.resi_min, max=self.resi_max
        )
        # Existing contract: signal NaN residuals by returning None rather than raising.
        if torch.isnan(resi).any():
            print('resi has nans!!')
            return None

        log_one_over_alpha = torch.log(one_over_alpha1)
        log_beta = torch.log(beta1)
        lgamma_beta = torch.lgamma(torch.pow(beta1, -1))

        # Diagnostics only: report (without aborting) NaNs in the log/lgamma
        # terms, e.g. from taking the log of a negative 1/alpha' or beta'.
        for term_name, term in (
            ('log_one_over_alpha', log_one_over_alpha),
            ('lgamma_beta', lgamma_beta),
            ('log_beta', log_beta),
        ):
            if torch.isnan(term).any():
                print(f'{term_name} has nan')

        loss = resi - log_one_over_alpha + lgamma_beta - log_beta
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        # reduction == 'none': per-element loss, matching nn.L1Loss(reduction='none').
        return loss
class TempCombLoss(nn.Module):
    """Weighted combination of an L1 loss and ``GenGaussLoss``.

    ``forward`` returns
    ``T1 * L1(mean, target) + T2 * GenGaussLoss(mean, one_over_alpha, beta, target)``.
    Both terms use the same ``reduction``; the eps/clamp arguments are passed
    to ``GenGaussLoss`` unchanged.

    Args:
        reduction: reduction applied by both component losses.
        alpha_eps, beta_eps, resi_min, resi_max: forwarded to ``GenGaussLoss``.
    """

    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3,
    ) -> None:
        super(TempCombLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max
        self.L_GenGauss = GenGaussLoss(
            reduction=self.reduction,
            alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
            resi_min=self.resi_min, resi_max=self.resi_max,
        )
        self.L_l1 = nn.L1Loss(reduction=self.reduction)

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
        T1: float, T2: float,
    ):
        """Return ``T1 * l1 + T2 * gen_gauss``.

        NOTE(review): ``T1``/``T2`` are presumably weights (e.g. a schedule)
        chosen by the training loop - confirm against callers.

        Returns:
            The combined loss, or ``None`` when the GenGauss term returns
            ``None`` (it does so, for example, when its residual contains NaNs).
        """
        l1 = self.L_l1(mean, target)
        l2 = self.L_GenGauss(mean, one_over_alpha, beta, target)
        if l2 is None:
            # Propagate the "no loss" signal instead of failing with an opaque
            # TypeError on ``T2 * None``.
            return None
        return T1 * l1 + T2 * l2
# x1 = torch.randn(4,3,32,32)
# x2 = torch.rand(4,3,32,32)
# x3 = torch.rand(4,3,32,32)
# x4 = torch.randn(4,3,32,32)
# L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))