File size: 3,046 Bytes
9e6cbab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import numpy as np

import torch
from torchvision import transforms

import torch.nn.functional as F

from torch.autograd.variable import Variable

# Per-channel mean/std normalization transform (the name indicates these are
# the ImageNet statistics).
NORMALIZE_IMAGENET = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The same statistics as tensors reshaped to (-1, 1, 1) and stored on `device`,
# so the helpers below can combine them with image tensors by broadcasting.
image_mean = torch.Tensor(NORMALIZE_IMAGENET.mean).to(device).view(-1, 1, 1)
image_std = torch.Tensor(NORMALIZE_IMAGENET.std).to(device).view(-1, 1, 1)

def normalize_img(x):
    """Move x to `device`, subtract `image_mean` and divide by `image_std`.

    Presumably x is an image tensor with values in [0, 1] whose channel axis
    broadcasts against the (-1, 1, 1)-shaped statistics -- confirm with callers.
    """
    centered = x.to(device) - image_mean
    return centered / image_std

def unnormalize_img(x):
    """Invert normalize_img: move x to `device`, multiply by `image_std`, add `image_mean`."""
    rescaled = x.to(device) * image_std
    return rescaled + image_mean

def round_pixel(x):
    """Quantize a normalized image to integer pixel values and re-normalize it.

    x is un-normalized, scaled by 255, rounded to the nearest integer and
    clamped to [0, 255]; the result is divided by 255 and normalized again.
    """
    pixels = unnormalize_img(x) * 255
    quantized = pixels.round().clamp(0, 255)
    return normalize_img(quantized / 255.0)

def project_linf(x, y, radius):
    """Return y plus the difference x - y clamped to [-radius, radius].

    The clamp is applied after scaling the difference by image_std and 255,
    so `radius` is expressed in that scaled unit (presumably 0-255 pixel
    levels -- confirm with callers). The clamped difference is scaled back
    before being added to y.
    """
    # Only the std and the 255 factor are applied; no mean is added to a difference.
    pixel_delta = 255 * ((x - y) * image_std)
    pixel_delta = pixel_delta.clamp(-radius, radius)
    return y + (pixel_delta / 255.0) / image_std

def psnr_clip(x, y, target_psnr):
    """Shrink the perturbation x - y so that PSNR(x, y) is at least target_psnr.

    The PSNR is computed over the whole tensor, for a peak value of 255, from
    the difference scaled by image_std and 255. When it is already at or above
    target_psnr the perturbation is left untouched; otherwise it is scaled
    down uniformly so that the PSNR becomes target_psnr.

    Returns y plus the (possibly scaled) perturbation, in the same normalized
    space as the inputs.
    """
    delta = x - y
    # Only the std and the 255 factor are applied; no mean is added to a difference.
    delta = 255 * (delta * image_std)
    psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2))
    if psnr < target_psnr:
        # Multiplying delta by 10**((psnr - target)/20) multiplies the MSE by
        # 10**((psnr - target)/10), which moves the PSNR to target_psnr.
        delta = torch.sqrt(10**((psnr-target_psnr)/10)) * delta
    # Back to the normalized space before re-applying the perturbation.
    delta = (delta / 255.0) / image_std
    return y + delta

def ssim_heatmap(img1, img2, window_size):
    """Compute the per-pixel, per-channel SSIM map between two images.

    Local statistics are taken with a normalized 2D Gaussian window
    (sigma = 1.5) applied to each channel independently through a grouped
    convolution. The window is built for the channel count, device and
    floating dtype of img1, so any number of channels is supported.

    Args:
        img1, img2: image tensors accepted by F.conv2d, with the channel axis
            third from last (e.g. N x C x H x W); both must share shape,
            device and dtype.
        window_size: side length of the Gaussian window. The padding is
            window_size // 2, so an odd size keeps the spatial dimensions
            while an even size makes each one larger by 1.

    Returns:
        The SSIM map. The stabilizing constants are C1 = 0.01**2 and
        C2 = 0.03**2.
    """
    n_channels = img1.shape[-3]
    half = window_size // 2
    # Integer inputs keep a float32 window, so conv2d rejects them as it did before.
    win_dtype = img1.dtype if img1.is_floating_point() else torch.float32

    # 1D Gaussian profile, built in float32 and normalized to sum to 1.
    gauss = torch.tensor(
        [np.exp(-(i - half)**2/float(2*1.5**2)) for i in range(window_size)],
        dtype=torch.float32,
        device=img1.device,
    )
    gauss = (gauss/gauss.sum()).unsqueeze(1)
    # Outer product -> 2D kernel, replicated once per channel for groups=n_channels.
    kernel_2d = gauss.mm(gauss.t()).to(win_dtype)
    window = kernel_2d.expand(n_channels, 1, window_size, window_size).contiguous()

    def local_mean(t):
        return F.conv2d(t, window, padding=half, groups=n_channels)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    # Local (co)variances via E[ab] - E[a]E[b].
    sigma1_sq = local_mean(img1*img1) - mu1_sq
    sigma2_sq = local_mean(img2*img2) - mu2_sq
    sigma12 = local_mean(img1*img2) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
    return ssim_map

def ssim_attenuation(x, y):
    """Re-weight the perturbation x - y with the SSIM heatmap of (x, y).

    The SSIM map (window size 17) is summed over dim 1, negative values are
    clamped to 0, and the resulting map multiplies x - y before it is added
    back to y.
    """
    heatmap = ssim_heatmap(x, y, window_size=17)  # 1xCxHxW
    # NOTE(review): the map is summed over channels and not rescaled (a min-max
    # rescaling used to be sketched here but is disabled), so weights can
    # exceed 1 -- confirm this is intended.
    weights = torch.clamp(heatmap.sum(dim=1, keepdim=True), min=0)
    return y + (x - y) * weights