Spaces:
Runtime error
Runtime error
from unittest import result | |
from matplotlib.pyplot import hist | |
from torch.utils import data | |
from torch.utils.data.dataset import Dataset | |
import os,torch | |
from PIL import Image | |
import torchvision.transforms as T | |
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM | |
import torch.nn.functional as F | |
from imaginaire.evaluation.segmentation import get_segmentation_hist_model,get_miou,compute_hist | |
import lpips | |
from easydict import EasyDict as edict | |
from tqdm import tqdm | |
import piq | |
from torch.utils.data import DataLoader | |
from piq import FID,KID | |
import numpy as np | |
# Input folders for the evaluation: generated images to score, and the
# ground-truth images they are compared against (paired by sorted filename
# position by the dataset/loader code below).
result_path, gt_path = (
    'result/Ours-pers-sin-sty',
    'dataset/CVACT/streetview_test',
)
class Dataset_img(Dataset):
    """Serve every image file in one directory, in sorted filename order.

    Each item is ``{'images': tensor}``, where the tensor is produced by
    ``torchvision.transforms.ToTensor`` applied to the RGB-converted image.

    Args:
        dir: Directory containing the images. Only regular, non-hidden files
            are listed, so sub-folders and OS metadata files (e.g.
            ``.DS_Store``) neither crash ``Image.open`` nor shift the
            position-based pairing with another ``Dataset_img``.
    """

    def __init__(self, dir):
        self.dir = dir
        self.datalist = sorted(
            name
            for name in os.listdir(dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(dir, name))
        )

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, index):
        path = os.path.join(self.dir, self.datalist[index])
        # Image.open is lazy and may keep the file handle open (notably for
        # multi-frame formats); the context manager guarantees it is closed.
        with Image.open(path) as src:
            rgb = src.convert('RGB')
        return {'images': T.ToTensor()(rgb)}
# Evaluation entry point: pair the i-th image of ``result_path`` with the i-th
# image of ``gt_path`` (both in sorted filename order), compute per-pair
# PSNR / SSIM / LPIPS (AlexNet, SqueezeNet) / RMSE, and print each mean.
if __name__ == "__main__":
    # The guard matters because the DataLoaders below start worker processes:
    # under the "spawn" start method each worker re-imports this module, and
    # unguarded top-level code would re-run the entire evaluation.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_gt = Dataset_img(gt_path)
    data_pred = Dataset_img(result_path)
    # Pairing is purely positional, so both folders must hold the same number
    # of images; zip() would otherwise silently drop the surplus and hide a
    # misalignment that makes every metric meaningless.
    if len(data_pred) != len(data_gt):
        raise ValueError(
            f'{result_path!r} has {len(data_pred)} images but {gt_path!r} has '
            f'{len(data_gt)}; predictions and ground truth must pair one-to-one'
        )

    loss_fn_alex = lpips.LPIPS(net='alex', eval_mode=True).to(device)
    loss_fn_squeeze = lpips.LPIPS(net='squeeze', eval_mode=True).to(device)

    results = edict()
    results.psnr = []
    results.ssim = []
    results.alex = []
    results.squeeze = []
    results.RMSE = []

    dataloader_pred = DataLoader(data_pred, batch_size=1, shuffle=False, num_workers=10)
    dataloader_gt = DataLoader(data_gt, batch_size=1, shuffle=False, num_workers=10)

    # Metrics only: no autograd graph is needed.
    with torch.no_grad():
        pairs = tqdm(zip(dataloader_pred, dataloader_gt), total=len(data_pred), ncols=100)
        for idx, (batch_pred, batch_gt) in enumerate(pairs):
            pred = batch_pred['images'].to(device)
            gt = batch_gt['images'].to(device)
            if pred.shape != gt.shape:
                raise ValueError(
                    f'shape mismatch for {data_pred.datalist[idx]!r} vs '
                    f'{data_gt.datalist[idx]!r}: {tuple(pred.shape)} != {tuple(gt.shape)}'
                )

            mse = F.mse_loss(pred, gt)
            # data_range is 1 here (same as the SSIM call), so
            # PSNR = -10 * log10(MSE); an exact match (MSE == 0) gives +inf.
            results.psnr.append(-10 * mse.log10().item())
            results.ssim.append(ssim(pred, gt, data_range=1.).item())
            # LPIPS inputs are rescaled from [0, 1] to [-1, 1].
            pred_lp = pred * 2. - 1
            gt_lp = gt * 2. - 1
            results.alex.append(torch.mean(loss_fn_alex(pred_lp, gt_lp)).item())
            results.squeeze.append(torch.mean(loss_fn_squeeze(pred_lp, gt_lp)).item())
            # RMSE reported on the 0-255 intensity scale.
            results.RMSE.append(torch.sqrt(mse).item() * 255)

    for metric_name, values in results.items():
        print("%-10s" % metric_name, ':', np.mean(values))