ReNoise-Inversion / src /metric_util.py
garibida's picture
Upload Files
d65c9b3
import torch
from PIL import Image
from torchvision import transforms
from src.lpips import LPIPS
import torch.nn as nn
dev = 'cuda'
to_tensor_transform = transforms.Compose([transforms.ToTensor()])
mse_loss = nn.MSELoss()
def calculate_l2_difference(image1, image2, device = 'cuda'):
if isinstance(image1, Image.Image):
image1 = to_tensor_transform(image1).to(device)
if isinstance(image2, Image.Image):
image2 = to_tensor_transform(image2).to(device)
mse = mse_loss(image1, image2).item()
return mse
def calculate_psnr(image1, image2, device = 'cuda'):
max_value = 1.0
if isinstance(image1, Image.Image):
image1 = to_tensor_transform(image1).to(device)
if isinstance(image2, Image.Image):
image2 = to_tensor_transform(image2).to(device)
mse = mse_loss(image1, image2)
psnr = 10 * torch.log10(max_value**2 / mse).item()
return psnr
loss_fn = LPIPS(net_type='vgg').to(dev).eval()
def calculate_lpips(image1, image2, device = 'cuda'):
if isinstance(image1, Image.Image):
image1 = to_tensor_transform(image1).to(device)
if isinstance(image2, Image.Image):
image2 = to_tensor_transform(image2).to(device)
loss = loss_fn(image1, image2).item()
return loss
def calculate_metrics(image1, image2, device = 'cuda', size=(512, 512)):
if isinstance(image1, Image.Image):
image1 = image1.resize(size)
image1 = to_tensor_transform(image1).to(device)
if isinstance(image2, Image.Image):
image2 = image2.resize(size)
image2 = to_tensor_transform(image2).to(device)
l2 = calculate_l2_difference(image1, image2, device)
psnr = calculate_psnr(image1, image2, device)
lpips = calculate_lpips(image1, image2, device)
return {"l2": l2, "psnr": psnr, "lpips": lpips}
def get_empty_metrics():
return {"l2": 0, "psnr": 0, "lpips": 0}
def print_results(results):
print(f"Reconstruction Metrics: L2: {results['l2']},\t PSNR: {results['psnr']},\t LPIPS: {results['lpips']}")