# File size: 4,040 bytes
# Revision: 6142a25
import torch
import numpy as np
import skimage.metrics
import lpips
from PIL import Image
from .sifid import SIFID
def resize_array(x, size=256):
    """
    Resize a batch of images to a square of the given side length.

    Args:
        x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
        size (int): Side length of the output images.

    Returns:
        (np.ndarray): Image array of shape (N, size, size, C) in range [0, 255].
            The input object itself is returned when it already has that size.
    """
    # Both spatial axes must match: a non-square batch whose height already
    # equals `size` still needs its width resized.
    if x.shape[1] == size and x.shape[2] == size:
        return x
    # An empty batch has nothing to hand to PIL; keep the (N, H, W, ...) layout
    # instead of collapsing to a shapeless empty array.
    if x.shape[0] == 0:
        return np.empty((0, size, size) + tuple(x.shape[3:]), dtype=x.dtype)
    resized = [
        np.array(Image.fromarray(frame).resize((size, size), resample=Image.BILINEAR))
        for frame in x
    ]
    return np.array(resized)
def resize_tensor(x, size=256):
    """
    Resize a batch of image tensors to a square of the given side length.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        size (int): Side length of the output images.

    Returns:
        (torch.Tensor): Image tensor of shape (N, C, size, size). The input
            tensor itself is returned when it already has that size.
    """
    # Check height AND width: testing the height alone would let a non-square
    # batch through unresized whenever its height happened to equal `size`.
    if x.shape[2] != size or x.shape[3] != size:
        x = torch.nn.functional.interpolate(
            x, size=(size, size), mode='bilinear', align_corners=False
        )
    return x
def normalise(x):
    """
    Convert a [0, 255] image array into a [-1, 1] channels-first tensor.

    Args:
        x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].

    Returns:
        (torch.Tensor): float32 image tensor of shape (N, C, H, W) in range [-1, 1].
    """
    unit_range = x.astype(np.float32) / 255
    centred = (unit_range - 0.5) / 0.5
    # NHWC -> NCHW
    return torch.from_numpy(centred).permute(0, 3, 1, 2)
def unormalise(x, vrange=(-1, 1)):
    """
    Un-normalise an image tensor to a [0, 255] uint8 channels-last array.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) with values
            nominally inside ``vrange``.
        vrange (sequence of two numbers): ``(low, high)`` value range of ``x``;
            ``low`` maps to 0 and ``high`` maps to 255.

    Returns:
        (np.ndarray): uint8 image array of shape (N, H, W, C) in range [0, 255].
    """
    low, high = vrange
    # detach() so tensors that still carry autograd history can be converted.
    x = (x.detach() - low) / (high - low)
    # Round instead of truncating so normalise -> unormalise is lossless, and
    # clamp so slightly out-of-range values saturate rather than wrap around
    # when cast to uint8.
    x = (x * 255).round().clamp(0, 255)
    x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
    return x.cpu().numpy().astype(np.uint8)
def compute_mse(x, y):
    """
    Compute the per-image mean squared error between two image batches.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].

    Returns:
        (1darray): Mean squared error of each image pair, shape (N,), float64.
    """
    # Promote to float before subtracting: on uint8 inputs the difference
    # wraps around modulo 256 and the square overflows, giving wrong errors.
    a = np.asarray(x, dtype=np.float64)
    b = np.asarray(y, dtype=np.float64)
    return np.square(a - b).reshape(a.shape[0], -1).mean(axis=1)
def compute_psnr(x, y):
    """
    Compute the per-image peak signal-to-noise ratio between two image batches.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].

    Returns:
        (1darray): PSNR in dB for each image pair, computed with a peak value
            of 255. Pairs with zero error yield ``inf``.
    """
    mse = compute_mse(x, y)
    # Identical images have zero MSE, so the PSNR is +inf by definition;
    # keep that result but suppress NumPy's divide-by-zero RuntimeWarning.
    with np.errstate(divide='ignore'):
        return 10 * np.log10(255 ** 2 / mse)
def compute_ssim(x, y):
    """
    Compute the per-image structural similarity index between two image batches.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].

    Returns:
        (np.ndarray): One SSIM value per image pair.
    """
    # Gaussian-weighted SSIM settings, applied to every pair.
    ssim_options = dict(
        channel_axis=2,
        gaussian_weights=True,
        sigma=1.5,
        use_sample_covariance=False,
        data_range=255,
    )
    scores = []
    for img_a, img_b in zip(x, y):
        scores.append(
            skimage.metrics.structural_similarity(img_a, img_b, **ssim_options)
        )
    return np.array(scores)
def compute_lpips(x, y, net='alex'):
    """
    Compute LPIPS between two batches of images.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        net (str or callable): Backbone name passed to ``lpips.LPIPS``, or an
            already constructed metric that is called as ``net(x, y)``.

    Returns:
        (np.ndarray): LPIPS distances with singleton dimensions squeezed out.
    """
    if isinstance(net, str):
        # Use the GPU when one exists, but keep working on CPU-only hosts.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        lpips_fn = lpips.LPIPS(net=net, verbose=False).to(device)
    else:
        lpips_fn = net
        # Run on whatever device the supplied metric lives on; a callable
        # without parameters leaves the inputs where they already are.
        params = net.parameters() if hasattr(net, 'parameters') else ()
        first_param = next(iter(params), None)
        device = first_param.device if first_param is not None else x.device
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        dist = lpips_fn(x.to(device), y.to(device))
    return dist.detach().cpu().numpy().squeeze()
def compute_sifid(x, y, net=None):
    """
    Compute SIFID between corresponding images of two batches.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        net (callable, optional): Metric called as ``net(xi, yi)`` on each
            image pair; a fresh ``SIFID()`` is built when omitted.

    Returns:
        (np.ndarray): One SIFID value per image pair.
    """
    scorer = SIFID() if net is None else net
    # map() walks both batches in lockstep, one image pair per call.
    return np.array(list(map(scorer, x, y)))