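# Evaluate super-resolution (SR) outputs against high-resolution (HR) ground truth
# with pyiqa metrics (PSNR-Y, SSIM, FID) and log the per-benchmark scores to an
# Excel sheet, one row per experiment.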
import os
import ssl
from os.path import join
from pathlib import Path
from statistics import mean

parent_path = Path(__file__).absolute().parent.parent
parent_path = os.path.abspath(parent_path)

os.environ["CURL_CA_BUNDLE"] = ""
ssl._create_default_https_context = ssl._create_unverified_context

cache_path = os.path.join(parent_path, 'cache')
os.environ["HF_DATASETS_CACHE"] = cache_path
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["TORCH_HOME"] = cache_path

import PIL
import numpy as np
import pandas as pd
import pyiqa
import torch
from PIL import Image
from tqdm import tqdm

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

metric_dict = {
    'psnr-Y': pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr'),
    'ssim': pyiqa.create_metric('ssim', color_space='ycbcr'),
    'fid': pyiqa.create_metric('fid'),
}
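# 'psnr-Y' and 'ssim' are full-reference metrics scored per image pair; 'fid' is
# computed over the two image folders as distributions, so eval_img_IQA below
# handles it separately.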


def load_img(path, target_size=None):
    image = Image.open(path).convert("RGB")
    if target_size:
        h, w = target_size
        image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image
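# load_img returns a 1xCxHxW float32 tensor in [0, 1]; pyiqa full-reference metrics
# accept image tensors in this layout and range, which is what eval_img_IQA feeds them.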


def eval_img_IQA(gt_dir, sr_dir, excel_path, metric_list, exp_name, data_name):
    gt_img_list = os.listdir(gt_dir)
    iqa_result = {}
    for metric in metric_list:
        iqa_metric = metric_dict[metric].to(device)
        score_fr_list = []
        if metric == 'fid':
            # FID is computed over the two directories as a whole, not per image pair.
            score_fr = iqa_metric(sr_dir, gt_dir)
            iqa_result[metric] = float(score_fr)
            print(f'{metric}: {float(score_fr)}')
        else:
            for img_name in tqdm(gt_img_list):
                base_name = img_name.split('.')[0]
                sr_img_name = f'{base_name}.png'
                gt_img_path = join(gt_dir, img_name)
                sr_img_path = join(sr_dir, sr_img_name)
                if not os.path.exists(sr_img_path):
                    print(f'File does not exist: {sr_img_path}')
                    continue
                # Resize the SR image to the ground-truth resolution before scoring.
                gt_img = load_img(gt_img_path, target_size=None)
                target_size = gt_img.shape[2:]
                sr_img = load_img(sr_img_path, target_size=target_size)
                score_fr = iqa_metric(sr_img, gt_img)
                if score_fr.shape == (1,):
                    score_fr = score_fr[0]
                if isinstance(score_fr, torch.Tensor):
                    score_fr = float(score_fr.cpu().numpy())
                else:
                    score_fr = float(score_fr)
                score_fr_list.append(score_fr)
            mean_score = mean(score_fr_list)
            iqa_result[metric] = float(mean_score)
            print(f'{metric}: {mean_score}')
    # Append (or update) one row per experiment in the Excel sheet.
    if os.path.exists(excel_path):
        df = pd.read_excel(excel_path)
    else:
        df = pd.DataFrame(columns=['exp'])
    new_index = len(df.index)
    exp_name = int(exp_name)
    if exp_name in df['exp'].to_list():
        new_index = df[df['exp'] == exp_name].index.tolist()[0]
    else:
        df.loc[new_index, 'exp'] = exp_name
    for index, metric in enumerate(metric_list):
        df_metric = f'{data_name}-{metric}'
        if df_metric not in df.columns.tolist():
            df[df_metric] = ''
        df.loc[new_index, df_metric] = iqa_result[metric]
    df.sort_values(by='exp', inplace=True)
    df.to_excel(excel_path, startcol=0, index=False)
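
# A minimal sketch of calling the evaluator directly, assuming hypothetical HR/SR
# directories laid out like the benchmark folders used in main() below:
#
#   eval_img_IQA(gt_dir='/path/to/HR', sr_dir='/path/to/SR',
#                excel_path='/path/to/IQA-val.xls',
#                metric_list=['psnr-Y', 'ssim'], exp_name=400000, data_name='Set5')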


def main():
    epoch = 400000
    add_name = ''
    exp_root = '/home/ma-user/work/code/SRDiff-main/checkpoints'
    model_type_list = ['diffsr_df2k4x_sam-pl_qs-zero']
    metric_list = ['psnr-Y', 'ssim', 'fid']
    benchmark_name_list = ['test_Set5', 'test_Set14', 'test_Urban100', 'test_Manga109', 'test_BSDS100']
    # if benchmark:
    for model_type in model_type_list:
        excel_path = join(exp_root, model_type, f'IQA-val-{model_type}.xls')
        for benchmark_name in benchmark_name_list:
            exp_dir = join(exp_root, f'{model_type}/results_{epoch}_{add_name}/benchmark/{benchmark_name}')
            gt_img_dir = join(exp_dir, 'HR')
            sr_img_dir = join(exp_dir, 'SR')
            data_name = benchmark_name[5:]  # strip the 'test_' prefix
            eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name)


if __name__ == '__main__':
    main()