# test / tools /eval_metrics.py
# Tu Bui
# first commit
# 6142a25
import torch
import numpy as np
import skimage.metrics
import lpips
from PIL import Image
from .sifid import SIFID
def resize_array(x, size=256):
    """
    Resize a batch of images to a square of the given size.

    Args:
        x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
        size (int): Side length of the (square) output images.
    Returns:
        (np.ndarray): Image array of shape (N, size, size, C) in range [0, 255].
            The input is returned as-is when it already has that size.
    """
    # Both spatial dimensions must match; checking the height alone would let
    # non-square inputs whose height equals `size` slip through unresized.
    if x.shape[1] == size and x.shape[2] == size:
        return x
    if x.shape[0] == 0:
        # An empty list would collapse to shape (0,); keep the batch layout.
        return np.empty((0, size, size) + tuple(x.shape[3:]), dtype=x.dtype)
    resized = [
        np.array(Image.fromarray(img).resize((size, size), resample=Image.BILINEAR))
        for img in x
    ]
    return np.array(resized)
def resize_tensor(x, size=256):
    """
    Resize a batch of image tensors to a square of the given size.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        size (int): Side length of the (square) output images.
    Returns:
        (torch.Tensor): Image tensor of shape (N, C, size, size). The input is
            returned as-is when it already has that spatial size.
    """
    # Compare height AND width: a non-square input whose height already equals
    # `size` still needs to be interpolated.
    if x.shape[2] != size or x.shape[3] != size:
        x = torch.nn.functional.interpolate(
            x, size=(size, size), mode='bilinear', align_corners=False
        )
    return x
def normalise(x):
    """
    Convert an image array to a channel-first tensor scaled to [-1, 1].

    Args:
        x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
    Returns:
        (torch.Tensor): Float32 image tensor of shape (N, C, H, W) in range [-1, 1].
    """
    # [0, 255] -> [0, 1]
    scaled = x.astype(np.float32) / 255
    # [0, 1] -> [-1, 1]
    centred = (scaled - 0.5) / 0.5
    # channel-last -> channel-first
    return torch.from_numpy(centred).permute(0, 3, 1, 2)
def unormalise(x, vrange=(-1, 1)):
    """
    Convert an image tensor back to a channel-last uint8 array in [0, 255].

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W), nominally in `vrange`.
        vrange (sequence): (min, max) of the value range of `x`.
    Returns:
        (np.ndarray): uint8 image array of shape (N, H, W, C) in range [0, 255].
    """
    x = (x - vrange[0]) / (vrange[1] - vrange[0])
    # Round instead of truncating so that float error (e.g. 0.9999 for pixel 1)
    # does not drop a level, and clamp so values outside `vrange` saturate
    # rather than wrapping around in the uint8 cast.
    x = (x * 255).round().clamp(0, 255)
    x = x.permute(0, 2, 3, 1)
    # detach() allows tensors that still carry gradient history.
    return x.detach().cpu().numpy().astype(np.uint8)
def compute_mse(x, y):
    """
    Compute the per-image mean squared error between two image batches.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
    Returns:
        (1darray): Mean squared error of each image pair, shape (N,).
    """
    # Promote to float first: subtracting / squaring uint8 arrays wraps modulo
    # 256 and silently yields wrong errors.
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    # Average over every non-batch axis; unlike reshape(N, -1) this also works
    # for an empty batch.
    return np.square(diff).mean(axis=tuple(range(1, diff.ndim)))
def compute_psnr(x, y):
    """
    Compute the per-image peak signal-to-noise ratio between two image batches.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
    Returns:
        (1darray): PSNR in dB for each image pair (inf for identical images).
    """
    # Cast before delegating so integer inputs cannot wrap around in the
    # subtraction performed by compute_mse.
    mse = compute_mse(np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64))
    # Identical images give mse == 0; the resulting inf is intended, so the
    # divide-by-zero warning is suppressed.
    with np.errstate(divide='ignore'):
        return 10 * np.log10(255 ** 2 / mse)
def compute_ssim(x, y):
    """
    Compute the structural similarity index for each pair of images.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
    Returns:
        (1darray): Structural similarity index of each image pair.
    """
    # Settings shared by every pair: channels on the last axis of each image,
    # Gaussian window with sigma 1.5, 8-bit data range.
    ssim_kwargs = dict(
        channel_axis=2,
        gaussian_weights=True,
        sigma=1.5,
        use_sample_covariance=False,
        data_range=255,
    )
    scores = []
    for img_x, img_y in zip(x, y):
        scores.append(skimage.metrics.structural_similarity(img_x, img_y, **ssim_kwargs))
    return np.array(scores)
def compute_lpips(x, y, net='alex'):
    """
    Compute LPIPS between two image batches.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        net (str or callable): Backbone name passed to `lpips.LPIPS`, or an
            already-built model to reuse. A string builds a new model on every
            call.
    Returns:
        (np.ndarray): LPIPS distance per image pair, with singleton
            dimensions squeezed out.
    """
    # Use the GPU when there is one, but do not require it.
    default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if isinstance(net, str):
        lpips_fn = lpips.LPIPS(net=net, verbose=False).to(default_device)
        device = default_device
    else:
        lpips_fn = net
        # Run on whatever device the supplied model lives on, so inputs and
        # weights never end up on different devices.
        params = net.parameters() if hasattr(net, 'parameters') else ()
        first_param = next(iter(params), None)
        device = default_device if first_param is None else first_param.device
    # Evaluation only: skip building an autograd graph.
    with torch.no_grad():
        dist = lpips_fn(x.to(device), y.to(device))
    return dist.detach().cpu().numpy().squeeze()
def compute_sifid(x, y, net=None):
    """
    Compute SIFID for each pair of images.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        net (callable, optional): Scorer called as net(xi, yi) on each pair;
            a new SIFID() is created when omitted.
    Returns:
        (np.ndarray): One SIFID score per image pair.
    """
    if net is None:
        scorer = SIFID()
    else:
        scorer = net
    scores = []
    for img_x, img_y in zip(x, y):
        scores.append(scorer(img_x, img_y))
    return np.array(scores)