# yichen-purdue's picture
# init
# 34fb220
# raw
# history blame
# No virus
# 8.23 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.autograd import Variable
import numpy as np
import cv2
# from vgg19_loss import VGG19Loss
# import pytorch_ssim
from .vgg19_loss import VGG19Loss
from . import pytorch_ssim
from abc import ABC, abstractmethod
from collections import OrderedDict
class abs_loss(ABC):
    """Common base class for the loss objects defined in this module.

    Subclasses override :meth:`loss`. The base implementation is a no-op
    and the class itself remains instantiable.
    """

    def loss(self, gt_img, pred_img):
        """Placeholder hook: accept a pair of inputs and return ``None``."""
        return None
class norm_loss(abs_loss):
    """p-norm of the residual between two batches, scaled by ``B * H * W``."""

    def __init__(self, norm=1):
        # Order ``p`` forwarded as the second argument of ``torch.norm``.
        self.norm = norm

    def loss(self, gt_img, pred_img):
        """Return ``torch.norm(gt_img - pred_img, p) / (H * W * B)``.

        ``gt_img`` is unpacked as a 4-D ``(B, C, H, W)`` shape; the channel
        count does not enter the divisor.
        """
        batch, _, height, width = gt_img.shape
        residual = gt_img - pred_img
        divisor = height * width * batch
        return torch.norm(residual, self.norm) / divisor
class ssim_loss(abs_loss):
    """SSIM-based dissimilarity computed with a uniform (mean) window."""

    def __init__(self, window_size=11, channel=1):
        """Store the window geometry and pre-build the averaging kernel.

        Args:
            window_size: Side length of the square averaging window.
            channel: Channel count the kernel is initially built for; ``loss``
                rebuilds the kernel when the input has a different count.
        """
        self.channel = channel
        self.window_size = window_size
        self.window = self.create_mean_window(window_size, channel)

    def loss(self, gt_img, pred_img):
        """Return ``1 - mean(SSIM map)`` for two ``(B, C, H, W)`` batches.

        Side effect: ``self.channel`` and ``self.window`` are updated to match
        the channel count, dtype and device of ``gt_img``.
        """
        _, channels, _, _ = gt_img.shape
        if channels != self.channel:
            self.channel = channels
            self.window = self.create_mean_window(self.window_size, self.channel)

        # ``Tensor.to(other)`` already adopts both dtype and device of
        # ``other``, so no further ``type_as`` cast is required.
        self.window = self.window.to(gt_img)
        return 1.0 - self.ssim_compute(gt_img, pred_img)

    def create_mean_window(self, window_size, channel):
        """Build a ``(channel, 1, ws, ws)`` float kernel with every tap ``1 / ws**2``.

        A plain tensor is returned; the deprecated ``Variable`` wrapper is a
        no-op on current PyTorch and has been dropped.
        """
        ones = torch.ones(channel, 1, window_size, window_size).float()
        return ones / (window_size * window_size)

    def ssim_compute(self, gt_img, pred_img):
        """Return the mean of the per-pixel SSIM map of the two inputs.

        Local means, variances and the covariance are obtained with a
        depthwise (``groups = channel``) convolution against ``self.window``.
        """
        kernel = self.window
        groups = self.channel
        pad = self.window_size // 2

        def local_mean(x):
            # Per-channel box filter over a window_size x window_size patch.
            return F.conv2d(x, kernel, padding=pad, groups=groups)

        mu_gt = local_mean(gt_img)
        mu_pred = local_mean(pred_img)

        mu_gt_sq = mu_gt.pow(2)
        mu_pred_sq = mu_pred.pow(2)
        mu_cross = mu_gt * mu_pred

        # E[x^2] - E[x]^2 and E[xy] - E[x]E[y] under the same window.
        var_gt = local_mean(gt_img * gt_img) - mu_gt_sq
        var_pred = local_mean(pred_img * pred_img) - mu_pred_sq
        covariance = local_mean(gt_img * pred_img) - mu_cross

        # Stabilising constants of the SSIM formula.
        # NOTE(review): these values match (K1*L)^2, (K2*L)^2 with L = 1, so
        # inputs are presumably expected in [0, 1] - confirm with callers.
        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        numerator = (2 * mu_cross + C1) * (2 * covariance + C2)
        denominator = (mu_gt_sq + mu_pred_sq + C1) * (var_gt + var_pred + C2)
        return (numerator / denominator).mean()
class hierarchical_ssim_loss(abs_loss):
    """Sum of ``1 - SSIM`` terms evaluated with several window sizes."""

    def __init__(self, patch_list: list):
        # One ``pytorch_ssim.SSIM`` evaluator per requested window size.
        self.ssim_loss_list = [
            pytorch_ssim.SSIM(window_size=size) for size in patch_list
        ]

    def loss(self, gt_img, pred_img):
        """Accumulate ``1 - SSIM`` over every evaluator, then divide by ``B``.

        ``gt_img`` is unpacked as a 4-D ``(B, C, H, W)`` shape.
        """
        batch, _, _, _ = gt_img.shape
        accumulated = 0.0
        for ssim_at_scale in self.ssim_loss_list:
            dissimilarity = 1.0 - ssim_at_scale(gt_img, pred_img)
            accumulated += dissimilarity
        return accumulated / batch
class vgg_loss(abs_loss):
    """Loss delegated to ``VGG19Loss`` and divided by the batch size."""

    def __init__(self):
        # The evaluator is created once and reused for every call to ``loss``.
        self.vgg19_ = VGG19Loss()

    def loss(self, gt_img, pred_img):
        """Return ``VGG19Loss(gt_img, pred_img, device) / B``.

        The device handed to the evaluator is the one ``pred_img`` lives on;
        ``gt_img`` is unpacked as a 4-D ``(B, C, H, W)`` shape.
        """
        batch, _, _, _ = gt_img.shape
        perceptual = self.vgg19_(gt_img, pred_img, pred_img.device)
        return perceptual / batch
class grad_loss(abs_loss):
    """Gradient-domain penalties built on 3x3 Sobel and Laplacian filters."""

    def __init__(self, k=4):
        """Configure the multi-scale term computed by ``loss``.

        Args:
            k: Number of scales visited by ``loss``; scale ``i`` resizes the
                inputs with ``T.Resize([h // 2**i])``.
        """
        # Bug fix: this used to assign the literal 4, silently ignoring ``k``.
        self.k = k

    def loss(self, disp_img, rgb_img=None):
        """Multi-scale gradient penalty on ``disp_img``, optionally edge-aware.

        At each scale the absolute Sobel responses of the resized ``disp_img``
        are summed. When ``rgb_img`` is given, each response is first
        multiplied by ``exp(-|Sobel response of rgb_img|)``, which shrinks the
        penalty where the guide image itself has strong gradients.

        Args:
            disp_img: 4-D tensor unpacked as ``(B, C, H, W)``.
            rgb_img: Optional guide tensor, resized with the same transform.

        Returns:
            The accumulated penalty divided by ``B`` (``0.0`` when ``k`` is 0).
        """
        b, c, h, w = disp_img.shape
        # NOTE(review): every scale is normalised by the full-resolution
        # pixel count, not by the size of the resized image - confirm that
        # this weighting is intended.
        normaliser = h * w * self.k

        total = 0.0
        for level in range(self.k):
            # A length-1 size makes torchvision match the *smaller* image edge
            # to this value (see the Resize docs), keeping the aspect ratio.
            resize = T.Resize([h // (2 ** level)])

            disp_dx, disp_dy = self.img_grad(resize(disp_img))

            if rgb_img is not None:
                rgb_dx, rgb_dy = self.img_grad(resize(rgb_img))
                weight_x = torch.exp(-torch.mean(torch.abs(rgb_dx), dim=1, keepdim=True))
                weight_y = torch.exp(-torch.mean(torch.abs(rgb_dy), dim=1, keepdim=True))
                level_sum = (torch.sum(torch.abs(disp_dx) * weight_x)
                             + torch.sum(torch.abs(disp_dy) * weight_y))
            else:
                level_sum = torch.sum(torch.abs(disp_dx)) + torch.sum(torch.abs(disp_dy))

            total += level_sum / normaliser

        return total / b

    def gloss(self, gt, pred):
        """Squared difference between the Sobel gradients of ``gt`` and ``pred``.

        Returns the sum over all elements divided by ``B * H * W``.
        """
        b, c, h, w = gt.shape
        gt_dx, gt_dy = self.img_grad(gt)
        pred_dx, pred_dy = self.img_grad(pred)
        squared_diff = (gt_dx - pred_dx) ** 2 + (gt_dy - pred_dy) ** 2
        return squared_diff.sum() / (b * h * w)

    def laploss(self, pred):
        """Sum of the absolute Laplacian response of ``pred`` over ``B * H * W``."""
        b, c, h, w = pred.shape
        lap = self.img_laplacian(pred)
        return torch.abs(lap).sum() / (b * h * w)

    @staticmethod
    def _filter_3x3(img, stencil):
        """Apply a 3x3 ``stencil`` to ``img`` with padding 1 and stride 1.

        The stencil is replicated over the input channels into a weight of
        shape ``(1, C, 3, 3)``, so the output has a single channel holding
        the sum of the per-channel responses.
        """
        _, c, _, _ = img.shape
        # ``.to(img)`` adopts the dtype and device of the input.
        weight = stencil.float().unsqueeze(0).expand(1, c, 3, 3).to(img)
        return F.conv2d(img, weight, padding=1, stride=1)

    def img_laplacian(self, img):
        """Filter ``img`` (``B x C x H x W``) with a 9-point Laplacian stencil."""
        laplacian = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])
        return self._filter_3x3(img, laplacian)

    def img_grad(self, img):
        """Compute image gradients by Sobel filtering.

        Args:
            img: ``B x C x H x W`` tensor.

        Returns:
            ``(dx, dy)``, each ``B x 1 x H x W`` (responses summed over the
            input channels).
        """
        ysobel = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
        # The x kernel is the transpose of the y kernel.
        xsobel = ysobel.transpose(0, 1)
        dx = self._filter_3x3(img, xsobel)
        dy = self._filter_3x3(img, ysobel)
        return dx, dy
class sharp_loss(abs_loss):
    """Sharpness term.

    Compares two images on the combination of three cues:
        1. Laplacian response
        2. image contrast (relative deviation from the local mean)
        3. image variance (squared deviation from the local mean)
    """

    def __init__(self, window_size=11, channel=1):
        """Store the window geometry and pre-build the averaging kernel.

        Args:
            window_size: Side length of the square window for the local mean.
            channel: Channel count the kernel is initially built for; ``loss``
                rebuilds the kernel when the input has a different count.
        """
        self.window_size = window_size
        self.channel = channel
        self.window = self.create_mean_window(window_size, self.channel)

    def loss(self, gt_img, pred_img):
        """Mean absolute difference between the sharpness maps of two images.

        For each image the map is ``-laplacian - contrast - variance``. The
        Laplacian has a single channel and broadcasts against the per-channel
        contrast and variance maps.

        Side effect: ``self.channel`` and ``self.window`` are updated to match
        the channel count, dtype and device of ``gt_img``.
        """
        _, channels, _, _ = gt_img.shape
        if channels != self.channel:
            self.channel = channels
            self.window = self.create_mean_window(self.window_size, self.channel)

        # ``Tensor.to(other)`` already adopts both dtype and device of
        # ``other``, so no further ``type_as`` cast is required.
        self.window = self.window.to(gt_img)

        kernel = self.window
        groups = self.channel
        pad = self.window_size // 2

        def sharpness(img):
            # Local mean, offset by 1e-6 so the contrast division below
            # cannot hit an exact zero mean.
            mu = F.conv2d(img, kernel, padding=pad, groups=groups) + 1e-6
            contrast = torch.absolute((img - mu) / mu)
            variance = (img - mu) ** 2
            laplacian = self.img_laplacian(img)
            return -laplacian - contrast - variance

        sharp_gt = sharpness(gt_img)
        sharp_pred = sharpness(pred_img)
        return torch.absolute(sharp_gt - sharp_pred).mean()

    def img_laplacian(self, img):
        """Filter ``img`` (``B x C x H x W``) with a 9-point Laplacian stencil.

        The weight has shape ``(1, C, 3, 3)``, so the output has a single
        channel holding the sum of the per-channel responses.
        """
        _, c, _, _ = img.shape
        laplacian = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])
        weight = laplacian.float().unsqueeze(0).expand(1, c, 3, 3).to(img)
        return F.conv2d(img, weight, padding=1, stride=1)

    def create_mean_window(self, window_size, channel):
        """Build a ``(channel, 1, ws, ws)`` float kernel with every tap ``1 / ws**2``.

        A plain tensor is returned; the deprecated ``Variable`` wrapper is a
        no-op on current PyTorch and has been dropped.
        """
        ones = torch.ones(channel, 1, window_size, window_size).float()
        return ones / (window_size * window_size)
if __name__ == '__main__':
    # Smoke test: run each tensor-only loss on random 3 x 3 x 128 x 128 batches
    # and print the shape and value of the result.
    a = torch.rand(3, 3, 128, 128)
    b = torch.rand(3, 3, 128, 128)

    ssim = ssim_loss()
    # Different images first, then each image against itself.
    for gt, pred in ((a, b), (a, a), (b, b)):
        loss = ssim.loss(gt, pred)
        print(loss.shape, loss)

    grad = grad_loss()
    # Bug fix: ``grad_loss.loss`` feeds ``rgb_img`` straight into ``T.Resize``,
    # which needs a tensor; the previous call passed the list ``[b, b]``.
    loss = grad.loss(a, b)
    print(loss.shape, loss)

    sharp = sharp_loss()
    loss = sharp.loss(a, b)
    print(loss.shape, loss)