# Bread / models/losses.py
# Last commit: 03b684c by huqiming513 ("Upload 7 files"); file size 9.47 kB.
import torch
import torch.nn as nn
from pytorch_msssim import SSIM, MS_SSIM
from torch.nn import L1Loss, MSELoss
from torchvision.models import vgg16
import torch.nn.functional as F
def compute_gradient(img):
    """Return forward finite differences of ``img`` along its last two axes.

    Returns a pair ``(d_rows, d_cols)``:
      * ``d_rows`` = ``img[..., i + 1, :] - img[..., i, :]`` (axis -2, one
        entry shorter along that axis);
      * ``d_cols`` = ``img[..., j + 1] - img[..., j]`` (axis -1, one entry
        shorter along that axis).

    Note: callers name the first result ``gradx``, but it is the difference
    along axis -2 (the height axis if the input is NCHW — confirm per caller).
    """
    below = img[..., 1:, :]
    above = img[..., :-1, :]
    right = img[..., 1:]
    left = img[..., :-1]
    return below - above, right - left
class GradientLoss(nn.Module):
    """L1 distance between the finite-difference gradients of two tensors.

    Both tensors are passed through ``compute_gradient``; the L1 losses of
    the two corresponding difference maps are summed.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.L1Loss()

    def forward(self, predict, target):
        pred_d0, pred_d1 = compute_gradient(predict)
        tgt_d0, tgt_d1 = compute_gradient(target)
        loss_d0 = self.loss(pred_d0, tgt_d0)
        loss_d1 = self.loss(pred_d1, tgt_d1)
        return loss_d0 + loss_d1
class SSIMLoss(nn.Module):
    """Structural-similarity loss: ``1 - SSIM(output, target)``.

    Args:
        channels: number of image channels, forwarded as ``SSIM(channel=...)``.
    """

    def __init__(self, channels):
        super().__init__()
        # data_range=1. assumes images are scaled to [0, 1].
        self.ssim = SSIM(data_range=1., size_average=True, channel=channels)

    def forward(self, output, target):
        similarity = self.ssim(output, target)
        return 1 - similarity
class SSIML1Loss(nn.Module):
    """Pixel L1 plus weighted SSIM loss: ``L1 + alpha * (1 - SSIM)``.

    Args:
        channels: number of image channels, forwarded as ``SSIM(channel=...)``.
    """

    def __init__(self, channels):
        super().__init__()
        self.l1_loss_func = nn.L1Loss()
        self.ssim = SSIM(data_range=1., size_average=True, channel=channels)
        self.alpha = 1.4  # weight of the SSIM term

    def forward(self, output, target):
        l1_term = self.l1_loss_func(output, target)
        ssim_term = 1 - self.ssim(output, target)
        return l1_term + self.alpha * ssim_term
class GradSSIML1Loss(nn.Module):
    """``L1 + alpha * (1 - SSIM) + 0.2 * GradientLoss``.

    Args:
        channels: number of image channels, forwarded as ``SSIM(channel=...)``.
    """

    def __init__(self, channels):
        super().__init__()
        self.l1_loss_func = nn.L1Loss()
        self.ssim = SSIM(data_range=1., size_average=True, channel=channels)
        self.grad_loss_func = GradientLoss()
        self.alpha = 1.4  # weight of the SSIM term; the gradient weight is fixed at 0.2

    def forward(self, output, target):
        l1_term = self.l1_loss_func(output, target)
        ssim_term = 1 - self.ssim(output, target)
        grad_term = self.grad_loss_func(output, target)
        return l1_term + self.alpha * ssim_term + 0.2 * grad_term
class SSIML2Loss(nn.Module):
    """Pixel MSE plus weighted SSIM loss: ``MSE + alpha * (1 - SSIM)``.

    Args:
        channels: number of image channels, forwarded as ``SSIM(channel=...)``.
    """

    def __init__(self, channels):
        super().__init__()
        self.l2_loss_func = nn.MSELoss()
        self.ssim = SSIM(data_range=1., size_average=True, channel=channels)
        self.alpha = 1.  # weight of the SSIM term

    def forward(self, output, target):
        mse_term = self.l2_loss_func(output, target)
        ssim_term = 1 - self.ssim(output, target)
        return mse_term + self.alpha * ssim_term
class MSSSIML1Loss(nn.Module):
    """Pixel L1 plus weighted multi-scale SSIM loss: ``L1 + alpha * (1 - MS-SSIM)``.

    Args:
        channels: number of image channels, forwarded as ``MS_SSIM(channel=...)``.
    """

    def __init__(self, channels):
        super().__init__()
        self.l1_loss_func = nn.L1Loss()
        self.ms_ssim = MS_SSIM(data_range=1., size_average=True, channel=channels)
        self.alpha = 1.0  # weight of the MS-SSIM term

    def forward(self, output, target):
        ms_ssim_term = 1 - self.ms_ssim(output, target)
        l1_term = self.l1_loss_func(output, target)
        return l1_term + self.alpha * ms_ssim_term
class MSSSIML2Loss(nn.Module):
    """Pixel MSE plus weighted multi-scale SSIM loss: ``MSE + alpha * (1 - MS-SSIM)``.

    Args:
        channels: number of image channels, forwarded as ``MS_SSIM(channel=...)``.
    """

    def __init__(self, channels):
        super().__init__()
        self.l2_loss_func = nn.MSELoss()
        self.ms_ssim = MS_SSIM(data_range=1., size_average=True, channel=channels)
        self.alpha = 1.2  # weight of the MS-SSIM term

    def forward(self, output, target):
        mse_term = self.l2_loss_func(output, target)
        ms_ssim_term = 1 - self.ms_ssim(output, target)
        return mse_term + self.alpha * ms_ssim_term
class PerLoss(torch.nn.Module):
    """Perceptual loss: mean MSE between frozen VGG16 feature maps.

    Features are tapped after the relu1_2, relu2_2 and relu3_3 activations
    of torchvision's pretrained VGG16 (``features[:16]``, whose last layer is
    relu3_3). All VGG parameters have ``requires_grad = False``.

    Args:
        device: device the VGG feature extractor is moved to. Defaults to
            ``'cuda'`` (the previously hard-coded behavior); pass ``'cpu'``
            or any ``torch.device`` to use the loss without a GPU.
    """

    def __init__(self, device='cuda'):
        super(PerLoss, self).__init__()
        vgg_model = vgg16(pretrained=True).features[:16]
        vgg_model = vgg_model.to(device)
        for param in vgg_model.parameters():
            param.requires_grad = False
        self.vgg_layers = vgg_model
        # Child-module name (Sequential index as a string) -> tapped layer name.
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3"
        }

    def output_features(self, x):
        """Run ``x`` through ``vgg_layers`` and return the tapped activations.

        Returns a list of tensors, in layer order, one per child whose name is
        a key of ``layer_name_mapping``.
        """
        output = {}
        # Iterate _modules (like Sequential.forward does) so every child runs,
        # even if the same module instance appears more than once.
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                output[self.layer_name_mapping[name]] = x
        return list(output.values())

    def forward(self, data, gt):
        """Return the mean over tapped layers of ``F.mse_loss`` of the features.

        Args:
            data: predicted image batch. If dim 1 has size 1, both ``data`` and
                ``gt`` are repeated to 3 channels (VGG16's first conv takes 3).
            gt: reference image batch.
        """
        # NOTE(review): inputs reach VGG without ImageNet mean/std
        # normalization; confirm this is intended.
        if data.shape[1] == 1:
            data = data.repeat(1, 3, 1, 1)
            gt = gt.repeat(1, 3, 1, 1)
        data_features = self.output_features(data)
        gt_features = self.output_features(gt)
        losses = [
            F.mse_loss(data_feature, gt_feature)
            for data_feature, gt_feature in zip(data_features, gt_features)
        ]
        return sum(losses) / len(losses)
class PerL1Loss(torch.nn.Module):
    """Pixel L1 plus VGG perceptual loss: ``L1 + 0.2 * PerLoss``."""

    def __init__(self):
        super().__init__()
        self.l1_loss_func = nn.L1Loss()
        self.per_loss_func = PerLoss().to('cuda')

    def forward(self, output, target):
        pixel_term = self.l1_loss_func(output, target)
        perceptual_term = self.per_loss_func(output, target)
        return pixel_term + 0.2 * perceptual_term
class MSPerL1Loss(torch.nn.Module):
    """``L1 + 1.2 * (1 - MS-SSIM) + 0.04 * PerLoss``.

    Args:
        channels: number of image channels, forwarded as ``MS_SSIM(channel=...)``.
    """

    def __init__(self, channels):
        super().__init__()
        self.l1_loss_func = nn.L1Loss()
        self.ms_ssim = MS_SSIM(data_range=1., size_average=True, channel=channels)
        self.per_loss_func = PerLoss().to('cuda')

    def forward(self, output, target):
        ms_ssim_term = 1 - self.ms_ssim(output, target)
        pixel_term = self.l1_loss_func(output, target)
        perceptual_term = self.per_loss_func(output, target)
        return pixel_term + 1.2 * ms_ssim_term + 0.04 * perceptual_term
class MSPerL2Loss(torch.nn.Module):
    """``MSE + 0.16 * (1 - MS-SSIM) + 0.2 * PerLoss``.

    Args:
        channels: number of image channels for MS-SSIM. Defaults to 3 (the
            previously hard-coded value), so existing ``MSPerL2Loss()`` calls
            are unchanged. Pass 1 for single-channel images; ``PerLoss``
            already repeats 1-channel inputs to 3 channels for VGG.
    """

    def __init__(self, channels=3):
        super(MSPerL2Loss, self).__init__()
        self.l2_loss_func = nn.MSELoss()
        self.ms_ssim = MS_SSIM(data_range=1., size_average=True, channel=channels)
        self.per_loss_func = PerLoss().to('cuda')

    def forward(self, output, target):
        ms_ssim_loss = 1 - self.ms_ssim(output, target)
        l2_loss = self.l2_loss_func(output, target)
        per_loss = self.per_loss_func(output, target)
        return l2_loss + 0.16 * ms_ssim_loss + 0.2 * per_loss
class TVLoss(torch.nn.Module):
    """Squared total-variation penalty of a 4-D tensor.

    Sums the squared neighbour differences along dim 3 and along dim 2. Each
    sum is divided by ``size(1) * size(2) * size(3)`` of the matching
    difference map. The batch dimension is summed, not averaged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, data):
        diff_w = data[:, :, :, :-1] - data[:, :, :, 1:]
        diff_h = data[:, :, :-1, :] - data[:, :, 1:, :]
        w_variance = diff_w.pow(2).sum()
        h_variance = diff_h.pow(2).sum()
        count_h = self._tensor_size(data[:, :, 1:, :])
        count_w = self._tensor_size(data[:, :, :, 1:])
        return h_variance / count_h + w_variance / count_w

    def _tensor_size(self, t):
        """Return ``size(1) * size(2) * size(3)`` of ``t``."""
        dims = t.size()
        return dims[1] * dims[2] * dims[3]
def safe_div(a, b, eps=1e-2):
    """Return ``a / b`` with ``b`` clamped from below at ``eps``.

    Any element of ``b`` smaller than ``eps`` (including zero and negative
    values) is replaced by ``eps`` before dividing.
    """
    denominator = torch.clamp(b, min=eps)
    return a / denominator
class WTVLoss(torch.nn.Module):
    """Squared total variation of ``data`` weighted by the gradients of ``aux``.

    Each neighbour difference of ``data`` (along dims 3 and 2) is divided by
    the absolute neighbour difference of ``aux`` at the same position. That
    divisor is clamped below at ``self.eps`` via ``safe_div``. The quotients
    are squared and summed. Each sum is divided by
    ``size(1) * size(2) * size(3)`` of the matching difference map.
    """

    def __init__(self):
        super().__init__()
        self.eps = 1e-2

    def forward(self, data, aux):
        data_dw = data[:, :, :, :-1] - data[:, :, :, 1:]
        data_dh = data[:, :, :-1, :] - data[:, :, 1:, :]
        aux_dw = (aux[:, :, :, :-1] - aux[:, :, :, 1:]).abs()
        aux_dh = (aux[:, :, :-1, :] - aux[:, :, 1:, :]).abs()
        w_variance = safe_div(data_dw, aux_dw, self.eps).pow(2).sum()
        h_variance = safe_div(data_dh, aux_dh, self.eps).pow(2).sum()
        count_h = self._tensor_size(data[:, :, 1:, :])
        count_w = self._tensor_size(data[:, :, :, 1:])
        return h_variance / count_h + w_variance / count_w

    def _tensor_size(self, t):
        """Return ``size(1) * size(2) * size(3)`` of ``t``."""
        dims = t.size()
        return dims[1] * dims[2] * dims[3]
class WTVLoss2(torch.nn.Module):
    """Edge-aware smoothness loss comparing gradient maps of two 4-D tensors.

    For each tensor, the absolute forward differences along dims 3 and 2 are
    zero-padded at the start of their dim (restoring the input size) and
    added, giving ``data_d`` and ``aux_d``. The loss is::

        0.5 * MSE(data_d, aux_d)
        + 4.0 * ||data_d / (aux_d + eps)||_2 / (C * H * W)

    The 2-norm is taken over the whole tensor, including the batch dim.
    """

    def __init__(self):
        super().__init__()
        self.eps = 1e-2
        self.criterion = nn.MSELoss()

    @staticmethod
    def _abs_grad_sum(t):
        """Sum of zero-padded absolute forward differences along dims 3 and 2."""
        d_w = F.pad(torch.abs(t[:, :, :, :-1] - t[:, :, :, 1:]), (1, 0, 0, 0))
        d_h = F.pad(torch.abs(t[:, :, :-1, :] - t[:, :, 1:, :]), (0, 0, 1, 0))
        return d_w + d_h

    def forward(self, data, aux):
        _, C, H, W = data.shape
        data_d = self._abs_grad_sum(data)
        aux_d = self._abs_grad_sum(aux)
        match_term = self.criterion(data_d, aux_d)
        ratio = data_d / (aux_d + self.eps)
        ratio_term = torch.norm(ratio) / (C * H * W)
        return match_term * 0.5 + ratio_term * 4.0
class MSTVPerL1Loss(torch.nn.Module):
    """``L1 + 1.2 * (1 - MS-SSIM) + 0.04 * PerLoss + 1e-7 * TVLoss(output)``.

    The total-variation term is computed on ``output`` only.

    Args:
        channels: number of image channels for MS-SSIM. Defaults to 3 (the
            previously hard-coded value), so existing ``MSTVPerL1Loss()``
            calls are unchanged. Pass 1 for single-channel images; ``PerLoss``
            already repeats 1-channel inputs to 3 channels for VGG.
    """

    def __init__(self, channels=3):
        super(MSTVPerL1Loss, self).__init__()
        self.l1_loss_func = nn.L1Loss()
        self.ms_ssim = MS_SSIM(data_range=1., size_average=True, channel=channels)
        self.per_loss_func = PerLoss().to('cuda')
        self.tv_loss_func = TVLoss()

    def forward(self, output, target):
        ms_ssim_loss = 1 - self.ms_ssim(output, target)
        l1_loss = self.l1_loss_func(output, target)
        per_loss = self.per_loss_func(output, target)
        tv_loss = self.tv_loss_func(output)
        return l1_loss + 1.2 * ms_ssim_loss + 0.04 * per_loss + 1e-7 * tv_loss
if __name__ == "__main__":
    # Smoke check: building the loss constructs PerLoss, which loads
    # pretrained VGG16 weights and moves the feature extractor to CUDA.
    loss_fn = MSTVPerL1Loss()