Spaces:
Running
Running
# ''' | |
# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py | |
# ''' | |
# | |
# import torch | |
# import torch.jit | |
# import torch.nn.functional as F | |
# | |
# | |
# @torch.jit.script | |
# def create_window(window_size: int, sigma: float, channel: int): | |
# ''' | |
# Create 1-D gauss kernel | |
# :param window_size: the size of gauss kernel | |
# :param sigma: sigma of normal distribution | |
# :param channel: input channel | |
# :return: 1D kernel | |
# ''' | |
# coords = torch.arange(window_size, dtype=torch.float) | |
# coords -= window_size // 2 | |
# | |
# g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) | |
# g /= g.sum() | |
# | |
# g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1) | |
# return g | |
# | |
# | |
# @torch.jit.script | |
# def _gaussian_filter(x, window_1d, use_padding: bool): | |
# ''' | |
# Blur input with 1-D kernel | |
# :param x: batch of tensors to be blurred
# :param window_1d: 1-D gauss kernel | |
# :param use_padding: padding image before conv | |
# :return: blurred tensors
# ''' | |
# C = x.shape[1] | |
# padding = 0 | |
# if use_padding: | |
# window_size = window_1d.shape[3] | |
# padding = window_size // 2 | |
# out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C) | |
# out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C) | |
# return out | |
# | |
# | |
# @torch.jit.script | |
# def ssim(X, Y, window, data_range: float, use_padding: bool = False): | |
# ''' | |
# Calculate ssim index for X and Y | |
# :param X: images [B, C, H, N_bins] | |
# :param Y: images [B, C, H, N_bins] | |
# :param window: 1-D gauss kernel | |
# :param data_range: value range of input images. (usually 1.0 or 255) | |
# :param use_padding: padding image before conv | |
# :return: | |
# ''' | |
# | |
# K1 = 0.01 | |
# K2 = 0.03 | |
# compensation = 1.0 | |
# | |
# C1 = (K1 * data_range) ** 2 | |
# C2 = (K2 * data_range) ** 2 | |
# | |
# mu1 = _gaussian_filter(X, window, use_padding) | |
# mu2 = _gaussian_filter(Y, window, use_padding) | |
# sigma1_sq = _gaussian_filter(X * X, window, use_padding) | |
# sigma2_sq = _gaussian_filter(Y * Y, window, use_padding) | |
# sigma12 = _gaussian_filter(X * Y, window, use_padding) | |
# | |
# mu1_sq = mu1.pow(2) | |
# mu2_sq = mu2.pow(2) | |
# mu1_mu2 = mu1 * mu2 | |
# | |
# sigma1_sq = compensation * (sigma1_sq - mu1_sq) | |
# sigma2_sq = compensation * (sigma2_sq - mu2_sq) | |
# sigma12 = compensation * (sigma12 - mu1_mu2) | |
# | |
# cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) | |
# # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan. | |
# cs_map = cs_map.clamp_min(0.) | |
# ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map | |
# | |
# ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW | |
# cs = cs_map.mean(dim=(1, 2, 3)) | |
# | |
# return ssim_val, cs | |
# | |
# | |
# @torch.jit.script | |
# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8): | |
# ''' | |
# interface of ms-ssim | |
# :param X: a batch of images, (N,C,H,W) | |
# :param Y: a batch of images, (N,C,H,W) | |
# :param window: 1-D gauss kernel | |
# :param data_range: value range of input images. (usually 1.0 or 255) | |
# :param weights: weights for different levels | |
# :param use_padding: padding image before conv | |
# :param eps: used to avoid NaN gradients.
# :return: | |
# ''' | |
# levels = weights.shape[0] | |
# cs_vals = [] | |
# ssim_vals = [] | |
# for _ in range(levels): | |
# ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding) | |
# # Works around an issue: when c = a ** b and a is 0, c.backward() makes a.grad inf.
# ssim_val = ssim_val.clamp_min(eps) | |
# cs = cs.clamp_min(eps) | |
# cs_vals.append(cs) | |
# | |
# ssim_vals.append(ssim_val) | |
# padding = (X.shape[2] % 2, X.shape[3] % 2) | |
# X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding) | |
# Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding) | |
# | |
# cs_vals = torch.stack(cs_vals, dim=0) | |
# ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0) | |
# return ms_ssim_val | |
# | |
# | |
# class SSIM(torch.jit.ScriptModule): | |
# __constants__ = ['data_range', 'use_padding'] | |
# | |
# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False): | |
# ''' | |
# :param window_size: the size of gauss kernel | |
# :param window_sigma: sigma of normal distribution | |
# :param data_range: value range of input images. (usually 1.0 or 255) | |
# :param channel: input channels (default: 3) | |
# :param use_padding: padding image before conv | |
# ''' | |
# super().__init__() | |
# assert window_size % 2 == 1, 'Window size must be odd.' | |
# window = create_window(window_size, window_sigma, channel) | |
# self.register_buffer('window', window) | |
# self.data_range = data_range | |
# self.use_padding = use_padding | |
# | |
# @torch.jit.script_method | |
# def forward(self, X, Y): | |
# r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding) | |
# return r[0] | |
# | |
# | |
# class MS_SSIM(torch.jit.ScriptModule): | |
# __constants__ = ['data_range', 'use_padding', 'eps'] | |
# | |
# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None, | |
# levels=None, eps=1e-8): | |
# ''' | |
# class for ms-ssim | |
# :param window_size: the size of gauss kernel | |
# :param window_sigma: sigma of normal distribution | |
# :param data_range: value range of input images. (usually 1.0 or 255) | |
# :param channel: input channels | |
# :param use_padding: padding image before conv | |
# :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) | |
# :param levels: number of downsampling | |
# :param eps: works around an issue: when c = a ** b and a is 0, c.backward() makes a.grad inf.
# ''' | |
# super().__init__() | |
# assert window_size % 2 == 1, 'Window size must be odd.' | |
# self.data_range = data_range | |
# self.use_padding = use_padding | |
# self.eps = eps | |
# | |
# window = create_window(window_size, window_sigma, channel) | |
# self.register_buffer('window', window) | |
# | |
# if weights is None: | |
# weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] | |
# weights = torch.tensor(weights, dtype=torch.float) | |
# | |
# if levels is not None: | |
# weights = weights[:levels] | |
# weights = weights / weights.sum() | |
# | |
# self.register_buffer('weights', weights) | |
# | |
# @torch.jit.script_method | |
# def forward(self, X, Y): | |
# return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights, | |
# use_padding=self.use_padding, eps=self.eps) | |
# | |
# | |
# if __name__ == '__main__': | |
# print('Simple Test') | |
# im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda') | |
# img1 = im / 255 | |
# img2 = img1 * 0.5 | |
# | |
# losser = SSIM(data_range=1.).cuda() | |
# loss = losser(img1, img2).mean() | |
# | |
# losser2 = MS_SSIM(data_range=1.).cuda() | |
# loss2 = losser2(img1, img2).mean() | |
# | |
# print(loss.item()) | |
# print(loss2.item()) | |
# | |
# if __name__ == '__main__': | |
# print('Training Test') | |
# import cv2 | |
# import torch.optim | |
# import numpy as np | |
# import imageio | |
# import time | |
# | |
# out_test_video = False | |
# # Avoid writing a GIF directly: it gets very large. Write an .mkv file first, then convert it to GIF with ffmpeg.
# video_use_gif = False | |
# | |
# im = cv2.imread('test_img1.jpg', 1) | |
# t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255. | |
# | |
# if out_test_video: | |
# if video_use_gif: | |
# fps = 0.5 | |
# out_wh = (im.shape[1] // 2, im.shape[0] // 2) | |
# suffix = '.gif' | |
# else: | |
# fps = 5 | |
# out_wh = (im.shape[1], im.shape[0]) | |
# suffix = '.mkv' | |
# video_last_time = time.perf_counter() | |
# video = imageio.get_writer('ssim_test' + suffix, fps=fps) | |
# | |
# # Test SSIM
# print('Training SSIM') | |
# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255. | |
# rand_im.requires_grad = True | |
# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8) | |
# losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda() | |
# ssim_score = 0 | |
# while ssim_score < 0.999: | |
# optim.zero_grad() | |
# loss = losser(rand_im, t_im) | |
# (-loss).sum().backward() | |
# ssim_score = loss.item() | |
# optim.step() | |
# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0] | |
# r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2) | |
# | |
# if out_test_video: | |
# if time.perf_counter() - video_last_time > 1. / fps: | |
# video_last_time = time.perf_counter() | |
# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB) | |
# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA) | |
# if isinstance(out_frame, cv2.UMat): | |
# out_frame = out_frame.get() | |
# video.append_data(out_frame) | |
# | |
# cv2.imshow('ssim', r_im) | |
# cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score) | |
# cv2.waitKey(1) | |
# | |
# if out_test_video: | |
# video.close() | |
# | |
# # Test MS-SSIM
# if out_test_video: | |
# if video_use_gif: | |
# fps = 0.5 | |
# out_wh = (im.shape[1] // 2, im.shape[0] // 2) | |
# suffix = '.gif' | |
# else: | |
# fps = 5 | |
# out_wh = (im.shape[1], im.shape[0]) | |
# suffix = '.mkv' | |
# video_last_time = time.perf_counter() | |
# video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps) | |
# | |
# print('Training MS_SSIM') | |
# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255. | |
# rand_im.requires_grad = True | |
# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8) | |
# losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda() | |
# ssim_score = 0 | |
# while ssim_score < 0.999: | |
# optim.zero_grad() | |
# loss = losser(rand_im, t_im) | |
# (-loss).sum().backward() | |
# ssim_score = loss.item() | |
# optim.step() | |
# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0] | |
# r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2) | |
# | |
# if out_test_video: | |
# if time.perf_counter() - video_last_time > 1. / fps: | |
# video_last_time = time.perf_counter() | |
# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB) | |
# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA) | |
# if isinstance(out_frame, cv2.UMat): | |
# out_frame = out_frame.get() | |
# video.append_data(out_frame) | |
# | |
# cv2.imshow('ms_ssim', r_im) | |
# cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score) | |
# cv2.waitKey(1) | |
# | |
# if out_test_video: | |
# video.close() | |
"""
SSIM (structural similarity index) for batches of PyTorch images.

Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
"""
import torch | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
import numpy as np | |
from math import exp | |
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel with ``window_size`` taps.

    Tap ``i`` is proportional to ``exp(-(i - window_size // 2) ** 2 / (2 * sigma ** 2))``.
    The taps are then divided by their sum so the kernel adds up to 1.

    :param window_size: number of taps in the kernel
    :param sigma: standard deviation of the Gaussian
    :return: 1-D tensor of shape (window_size,), built with ``torch.Tensor``
        (so it uses the default floating dtype)
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    taps = []
    for offset in range(window_size):
        taps.append(exp(-(offset - center) ** 2 / two_sigma_sq))
    kernel = torch.Tensor(taps)
    return kernel / kernel.sum()
def create_window(window_size, channel, sigma=1.5):
    """Build a depthwise 2-D Gaussian window for ``F.conv2d(..., groups=channel)``.

    The 2-D kernel is the outer product of the 1-D kernel from ``gaussian``
    with itself. It is replicated once per channel.

    :param window_size: side length of the square kernel
    :param channel: number of input channels (one kernel copy per channel)
    :param sigma: standard deviation of the Gaussian. The default of 1.5 is the
        value this function previously hard-coded.
    :return: float32 tensor of shape (channel, 1, window_size, window_size)
    """
    kernel_1d = gaussian(window_size, sigma).unsqueeze(1)  # (window_size, 1)
    # Outer product -> (window_size, window_size), then add batch/group dims.
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    # expand() only creates a view; contiguous() materializes one real copy per
    # channel. The deprecated autograd ``Variable`` wrapper is not needed:
    # plain tensors carry autograd state themselves (requires_grad=False here).
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
def _ssim(img1, img2, window, window_size, channel, size_average=True, data_range=1.0):
    """Compute SSIM between two image batches using a precomputed window.

    Local means, variances and the covariance are estimated with a depthwise
    ``F.conv2d`` (``groups=channel``) and zero padding of ``window_size // 2``.
    For an odd ``window_size`` x ``window_size`` window, the SSIM map keeps the
    input's spatial size.

    :param img1: first image batch, 4-D (N, C, H, W) with C == channel
    :param img2: second image batch, same shape as img1
    :param window: depthwise kernel, expected (channel, 1, window_size, window_size)
    :param window_size: kernel side length; only used to derive the padding
    :param channel: number of channels, passed as ``groups`` to conv2d
    :param size_average: if True, return the scalar mean of the SSIM map.
        Otherwise return the map averaged over the channel dimension only,
        with shape (N, H', W').
        NOTE(review): upstream pytorch-ssim reduces to one score per image
        here. Confirm that the per-pixel map is what callers expect.
    :param data_range: dynamic range L of the inputs (1.0 for [0, 1] images,
        255.0 for 8-bit images). It sets C1 = (0.01 * L) ** 2 and
        C2 = (0.03 * L) ** 2. The default reproduces the previously
        hard-coded constants.
    :return: 0-d tensor, or an (N, H', W') map when size_average is False
    """
    padding = window_size // 2

    def local_mean(x):
        # Gaussian-weighted local average, computed independently per channel.
        return F.conv2d(x, window, padding=padding, groups=channel)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    # Local (co)variances via Var[X] = E[X^2] - E[X]^2 and Cov = E[XY] - E[X]E[Y].
    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    # Small constants added to numerator and denominator so the ratios stay
    # well-defined where local means/variances are near zero. They scale with
    # the square of the input range.
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(1)
class SSIM(torch.nn.Module):
    """SSIM as a module that caches its Gaussian window between calls.

    The window is rebuilt whenever the input's channel count, device or dtype
    differs from the cached window's.

    :param window_size: side length of the Gaussian window (default 11)
    :param size_average: forwarded to ``_ssim``; True returns the scalar mean
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        # Plain attribute, not a registered buffer. It is therefore not part of
        # state_dict and is not moved by .to()/.cuda(); forward() adapts it instead.
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return ``_ssim`` of two 4-D batches (channels at dim 1)."""
        (_, channel, _, _) = img1.size()
        window = self.window
        # Compare the device, not the legacy type string. 'torch.cuda.FloatTensor'
        # omits the GPU index, so a window cached on cuda:0 would otherwise be
        # reused for inputs on cuda:1 and conv2d would fail.
        if channel != self.channel or window.device != img1.device or window.dtype != img1.dtype:
            window = create_window(self.window_size, channel).to(device=img1.device, dtype=img1.dtype)
            self.window = window
            self.channel = channel
        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
# Module-level window cache used by the functional ``ssim``; rebuilt on demand.
window = None


def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM between two 4-D image batches (channels at dim 1).

    The Gaussian window is kept in the module-level ``window`` cache. It is
    rebuilt whenever its shape no longer matches ``window_size`` and the
    input's channel count, or its device/dtype differ from ``img1``'s.

    :param img1: first image batch
    :param img2: second image batch, same shape as img1
    :param window_size: side length of the Gaussian window (default 11)
    :param size_average: forwarded to ``_ssim``
    :return: result of ``_ssim`` for the two batches
    """
    global window
    (_, channel, _, _) = img1.size()
    # Without this check, the first window built would be reused forever.
    # Later calls with another channel count, window size, device or dtype
    # would then get a stale window: wrong results, or a conv2d error.
    stale = (
        window is None
        or window.shape != (channel, 1, window_size, window_size)
        or window.device != img1.device
        or window.dtype != img1.dtype
    )
    if stale:
        window = create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)
    return _ssim(img1, img2, window, window_size, channel, size_average)