test / tools /eval_metrics.py
Tu Bui
first commit
6142a25
raw
history blame
No virus
4.04 kB
import torch
import numpy as np
import skimage.metrics
import lpips
from PIL import Image
from .sifid import SIFID
def resize_array(x, size=256):
    """
    Resize a batch of images to a square of the given size.

    Both spatial dimensions are checked, so a non-square batch whose height
    already equals ``size`` is still resized.

    Args:
        x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
        size (int): Side length of the square output images.
    Returns:
        (np.ndarray): Image array of shape (N, size, size, C) in range [0, 255].
            The input object itself is returned when no resize is needed.
    """
    if x.shape[1] == size and x.shape[2] == size:
        return x
    if x.shape[0] == 0:
        # Nothing to resize; keep a well-formed (0, size, size, ...) shape
        # instead of collapsing to a 1-D empty array.
        return np.empty((0, size, size) + tuple(x.shape[3:]), dtype=x.dtype)
    resized = [
        np.array(Image.fromarray(img).resize((size, size), resample=Image.BILINEAR))
        for img in x
    ]
    return np.array(resized)
def resize_tensor(x, size=256):
    """
    Resize a batch of image tensors to a square of the given size.

    Both spatial dimensions are checked, so a non-square batch whose height
    already equals ``size`` is still resized.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        size (int): Side length of the square output images.
    Returns:
        (torch.Tensor): Image tensor of shape (N, C, size, size). The input
            tensor itself is returned when no resize is needed.
    """
    if tuple(x.shape[2:4]) != (size, size):
        x = torch.nn.functional.interpolate(
            x, size=(size, size), mode='bilinear', align_corners=False
        )
    return x
def normalise(x):
    """
    Convert a channel-last image array into a channel-first tensor in [-1, 1].

    Args:
        x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
    Returns:
        (torch.Tensor): float32 image tensor of shape (N, C, H, W) in range [-1, 1].
    """
    unit = x.astype(np.float32) / 255  # [0, 255] -> [0, 1]
    centred = (unit - 0.5) / 0.5  # [0, 1] -> [-1, 1]
    # NHWC -> NCHW
    return torch.from_numpy(centred).permute(0, 3, 1, 2)
def unormalise(x, vrange=(-1, 1)):
    """
    Unormalise image tensor to range [0, 255] and RGB array.

    Values are rounded to the nearest integer and clamped to [0, 255] before
    the uint8 cast, so inputs slightly outside ``vrange`` saturate instead of
    wrapping around, and float rounding error does not drop a level.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range ``vrange``.
        vrange (sequence of two numbers): (min, max) of the input value range.
    Returns:
        (np.ndarray): uint8 image array of shape (N, H, W, C) in range [0, 255].
    """
    lo, hi = vrange[0], vrange[1]
    x = (x - lo) / (hi - lo)  # vrange -> [0, 1]
    x = (x * 255).round().clamp(0, 255)
    x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
    # detach() so tensors that still track gradients can be converted.
    return x.detach().cpu().numpy().astype(np.uint8)
def compute_mse(x, y):
    """
    Compute the per-image mean squared error between two image batches.

    Inputs are promoted to float64 before subtracting; with uint8 arrays the
    difference and its square would otherwise wrap around modulo 256.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
    Returns:
        (1darray): Mean squared error of each image, shape (N,).
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    # Average over every axis except the batch axis (also valid for N == 0).
    return np.square(diff).mean(axis=tuple(range(1, diff.ndim)))
def compute_psnr(x, y):
    """
    Compute the per-image peak signal-to-noise ratio between two image batches.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
    Returns:
        (1darray): PSNR in dB of each image, shape (N,). Identical images
            yield ``inf``.
    """
    # Promote to float first so uint8 inputs cannot overflow in the
    # difference/square computed by compute_mse.
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    # A zero MSE legitimately maps to infinite PSNR; silence the warning only.
    with np.errstate(divide='ignore'):
        return 10 * np.log10(255 ** 2 / compute_mse(x, y))
def compute_ssim(x, y):
    """
    Compute the per-image structural similarity index between two image batches.

    Args:
        x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
        y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
    Returns:
        (1darray): Structural similarity index of each image pair.
    """
    # Shared settings for every pair: channel-last images, Gaussian window
    # with sigma 1.5, population covariance, 8-bit data range.
    ssim_kwargs = dict(
        channel_axis=2,
        gaussian_weights=True,
        sigma=1.5,
        use_sample_covariance=False,
        data_range=255,
    )
    scores = []
    for img_a, img_b in zip(x, y):
        scores.append(skimage.metrics.structural_similarity(img_a, img_b, **ssim_kwargs))
    return np.array(scores)
def compute_lpips(x, y, net='alex'):
    """
    Compute LPIPS between two image batches.

    The computation runs on the device of the supplied network when it has
    parameters; otherwise on CUDA when available, falling back to CPU.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        net (str or callable): Backbone name passed to ``lpips.LPIPS``, or an
            already constructed LPIPS-like callable ``net(x, y)``. Passing a
            constructed network avoids rebuilding the model on every call.
    Returns:
        (np.ndarray): LPIPS distance per image, with singleton axes squeezed.
    """
    default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if isinstance(net, str):
        device = default_device
        lpips_fn = lpips.LPIPS(net=net, verbose=False).to(device)
    else:
        lpips_fn = net
        # Follow the network's own device so inputs and weights always match.
        params = getattr(net, 'parameters', None)
        first_param = next(iter(params()), None) if callable(params) else None
        device = first_param.device if first_param is not None else default_device
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        dist = lpips_fn(x.to(device), y.to(device))
    return dist.detach().cpu().numpy().squeeze()
def compute_sifid(x, y, net=None):
    """
    Compute SIFID between corresponding images of two batches.

    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        net (callable, optional): Scorer called as ``net(xi, yi)`` for each
            image pair. A new ``SIFID()`` is built when omitted.
    Returns:
        (np.ndarray): One SIFID score per image pair.
    """
    if net is None:
        scorer = SIFID()
    else:
        scorer = net
    scores = []
    for img_a, img_b in zip(x, y):
        scores.append(scorer(img_a, img_b))
    return np.array(scores)