"""Evaluate a super-resolution model: per-class PSNR/SSIM, optional image dumps.

Run as a script to build the test dataset/model from a YAML config and a
checkpoint, then print PSNR/SSIM (overall and optionally per class).
"""
import argparse
import json
import os
import math
from functools import partial

import seaborn as sns
import cv2.dnn
import numpy as np
import yaml
import torch
from einops import rearrange
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm

import datasets
import models
import utils

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def batched_predict(model, img, coord, bsize):
    """Run ``model(img, coord)`` without gradient tracking and return its output.

    ``bsize`` is accepted for API compatibility but is currently unused: the
    whole query is processed in a single forward call.
    """
    with torch.no_grad():
        pred = model(img, coord)
    return pred


def _norm_tensors(spec):
    """Build (sub, div) tensors of shape (1, C, 1, 1) on ``device``.

    ``spec`` is a ``{'sub': [...], 'div': [...]}`` mapping; C is the length
    of those lists.
    """
    sub = torch.FloatTensor(spec['sub']).view(1, -1, 1, 1).to(device)
    div = torch.FloatTensor(spec['div']).view(1, -1, 1, 1).to(device)
    return sub, div


def _tensor_to_uint8_hwc(tensor):
    """Convert one (C, H, W) image tensor to an (H, W, C) uint8 numpy array.

    Values are multiplied by 255 and clipped to [0, 255], i.e. the tensor is
    expected to hold intensities roughly in [0, 1].
    """
    arr = tensor.detach().cpu().numpy() * 255
    arr = np.clip(arr, a_min=0, a_max=255).astype(np.uint8)
    return rearrange(arr, 'C H W -> H W C')


def _save_image_triplet(save_path, file_name, scale_ratio, lr_img, pred_img, gt_img, psnr, ssim):
    """Write the input, prediction and ground-truth images of one sample.

    The prediction's file name embeds the scale ratio and its PSNR/SSIM.
    """
    # NOTE(review): cv2.imwrite expects BGR channel order; if these tensors are
    # RGB the saved colours are swapped — confirm against the dataset loader.
    cv2.imwrite(f'{save_path}/{file_name}_Ori.png', _tensor_to_uint8_hwc(lr_img))
    cv2.imwrite(f'{save_path}/{file_name}_{scale_ratio}X_{psnr:.2f}_{ssim:.4f}.png',
                _tensor_to_uint8_hwc(pred_img))
    cv2.imwrite(f'{save_path}/{file_name}_GT.png', _tensor_to_uint8_hwc(gt_img))


def _save_feature_maps(save_path, file_name, feat_maps):
    """Save every map in ``feat_maps`` as a 'jet' heatmap PNG.

    Each map is passed through a sigmoid, then min-max rescaled to [0, 1]
    for display. Constant maps (zero range) are drawn as all zeros instead
    of producing NaNs.
    """
    for idx_f, f_tensor in enumerate(feat_maps):
        f_map = torch.sigmoid(f_tensor).detach().cpu().numpy()
        f_min, f_max = np.min(f_map), np.max(f_map)
        span = f_max - f_min
        f_map = (f_map - f_min) / span if span > 0 else np.zeros_like(f_map)
        sns.heatmap(f_map, vmin=0, vmax=1, cmap="jet", center=0.5)
        plt.axis('off')
        plt.xticks([])
        plt.yticks([])
        plt.savefig(f'{save_path}/{file_name}_{idx_f}.png', dpi=600)
        plt.close()


def eval_psnr(loader, class_names, model, data_norm=None, eval_type=None,
              save_fig=False, save_featmap=False, scale_ratio=1, save_path=None,
              verbose=False, crop_border=4, cal_metrics=True,
              ):
    """Evaluate ``model`` on ``loader`` and average PSNR/SSIM per class.

    Args:
        loader: iterable of batch dicts with at least 'img', 'gt' and
            'class_name'; an optional 'filename' entry enables image saving.
        class_names: class labels handed to ``utils.Averager``.
        model: called as ``model(normalized_img, gt_spatial_size)``.
        data_norm: ``{'img': {'sub', 'div'}, 'gt': {'sub', 'div'}}``. The
            input is normalised with the 'img' entry; the output is
            de-normalised with the 'gt' entry. Defaults to identity.
        eval_type: None or 'psnr+ssim' (PSNR + SSIM). 'div2k-<s>' and
            'benchmark-<s>' only provide a single PSNR function, so they are
            rejected when ``cal_metrics`` is True.
        save_fig: write input/prediction/GT PNGs to ``save_path``.
        save_featmap: the model is expected to return ``(preds, featmaps)``;
            feature-map heatmaps and the image triplet are written.
        scale_ratio: used only in saved file names.
        save_path: output directory for saved images.
        verbose: show running overall PSNR/SSIM on the progress bar.
        crop_border: border passed to the metric functions (cast to int).
        cal_metrics: if False, the metrics are replaced by placeholder ones.

    Returns:
        ``(psnr_averager.item(), ssim_averager.item())``. The script below
        reads per-class values and an ``'all'`` key from them.

    Raises:
        NotImplementedError: unknown ``eval_type``, or one without a
            PSNR/SSIM pair while ``cal_metrics`` is True.
        ValueError: in ``save_featmap`` mode, the feature maps do not have 6
            channels.
    """
    crop_border = int(crop_border) if crop_border else crop_border
    print('crop border: ', crop_border)
    model.eval()

    if data_norm is None:
        data_norm = {
            'img': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }
    img_sub, img_div = _norm_tensors(data_norm['img'])
    gt_sub, gt_div = _norm_tensors(data_norm['gt'])

    if eval_type is None or eval_type == 'psnr+ssim':
        metric_fn = [utils.calculate_psnr_pt, utils.calculate_ssim_pt]
    elif eval_type.startswith('div2k'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='div2k', scale=scale)
    elif eval_type.startswith('benchmark'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='benchmark', scale=scale)
    else:
        raise NotImplementedError
    if cal_metrics and not isinstance(metric_fn, (list, tuple)):
        # The loop needs a (psnr_fn, ssim_fn) pair that accepts crop_border;
        # a single partial cannot be indexed and used to crash mid-evaluation.
        raise NotImplementedError(
            f'eval_type={eval_type!r} does not provide a PSNR/SSIM pair; '
            'use eval_type=None or "psnr+ssim", or pass cal_metrics=False')

    val_res_psnr = utils.Averager(class_names)
    val_res_ssim = utils.Averager(class_names)
    save_images = save_fig or save_featmap

    pbar = tqdm(loader, leave=False, desc='val')
    for batch in pbar:
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.to(device)

        img = (batch['img'] - img_sub) / img_div
        returned_featmap = None
        with torch.no_grad():
            # The model is given the target spatial size (H, W) of the GT.
            preds = model(img, batch['gt'].shape[-2:])
        if save_featmap:
            # Feature-map mode: output is indexed as (list_of_preds, featmaps).
            pred = preds[0][-1]
            returned_featmap = preds[1]
            if returned_featmap.size(1) != 6:
                raise ValueError(
                    f'expected 6 feature-map channels, got {returned_featmap.size(1)}')
        else:
            # A list output holds several predictions; the last one is scored.
            # Any other output is the prediction itself (previously `pred` was
            # left unbound, or stale from the previous batch, in that case).
            pred = preds[-1] if isinstance(preds, list) else preds
        pred = pred * gt_div + gt_sub

        if cal_metrics:
            res_psnr = metric_fn[0](pred, batch['gt'], crop_border=crop_border)
            res_ssim = metric_fn[1](pred, batch['gt'], crop_border=crop_border)
        else:
            # Placeholder scores so the averaging / file naming still work.
            res_psnr = torch.ones(len(pred))
            res_ssim = torch.ones(len(pred))

        file_names = batch.get('filename', None)
        if file_names is not None and save_images:
            for idx in range(len(batch['img'])):
                if save_featmap:
                    _save_feature_maps(save_path, file_names[idx], returned_featmap[idx])
                _save_image_triplet(
                    save_path, file_names[idx], scale_ratio,
                    batch['img'][idx], pred[idx], batch['gt'][idx],
                    float(res_psnr[idx]), float(res_ssim[idx]),
                )

        val_res_psnr.add(batch['class_name'], res_psnr)
        val_res_ssim.add(batch['class_name'], res_ssim)
        if verbose:
            pbar.set_description(
                'val psnr: {:.4f} ssim: {:.4f}'.format(val_res_psnr.item()['all'], val_res_ssim.item()['all']))
    return val_res_psnr.item(), val_res_ssim.item()


def _str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` treats every non-empty string, including
    'False', as True.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/test_UC_INR_mysr.yaml')
    parser.add_argument('--model', default='checkpoints/EXP20220610_5/epoch-best.pth')
    parser.add_argument('--scale_ratio', default=4, type=float)
    parser.add_argument('--save_fig', default=False, type=_str2bool)
    parser.add_argument('--save_featmap', default=False, type=_str2bool)
    parser.add_argument('--save_path', default='tmp', type=str)
    parser.add_argument('--cal_metrics', default=True, type=_str2bool)
    parser.add_argument('--return_class_metrics', default=False, type=_str2bool)
    parser.add_argument('--dataset_name', default='UC', type=str)
    parser.add_argument('--root_path', default=None, type=str,
                        help='override the dataset root for --dataset_name')
    parser.add_argument('--split_file', default=None, type=str,
                        help='override the split JSON for --dataset_name')
    args = parser.parse_args()

    # The config is a local, trusted file.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    root_split_file = {'UC': {
        'root_path': '/Users/kyanchen/Documents/UC/256',
        'split_file': '/Users/kyanchen/My_Code/sr/data_split/UC_split.json'
    },
        'AID': {
            'root_path': '/data/kyanchen/datasets/AID',
            'split_file': 'data_split/AID_split.json'
        }
    }
    dataset_defaults = root_split_file.get(args.dataset_name, {})
    root_path = args.root_path or dataset_defaults.get('root_path')
    split_file = args.split_file or dataset_defaults.get('split_file')
    if root_path is None or split_file is None:
        raise ValueError(
            f'unknown dataset {args.dataset_name!r}; pass --root_path and --split_file')

    config['test_dataset']['dataset']['args']['root_path'] = root_path
    config['test_dataset']['dataset']['args']['split_file'] = split_file
    config['test_dataset']['wrapper']['args']['scale_ratio'] = args.scale_ratio

    spec = config['test_dataset']
    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})
    loader = DataLoader(dataset, batch_size=spec['batch_size'], num_workers=0,
                        pin_memory=True, shuffle=False, drop_last=False)

    if not os.path.exists(args.model):
        raise FileNotFoundError(f'checkpoint not found: {args.model}')
    # torch.load unpickles the checkpoint: only load files you trust.
    model_spec = torch.load(args.model, map_location='cpu')['model']
    print(model_spec['args'])
    model = models.make(model_spec, load_sd=True).to(device)

    with open(split_file) as f:
        file_names = json.load(f)['test']
    # Class = name of each test file's parent directory; sorted for determinism.
    class_names = sorted({os.path.basename(os.path.dirname(x)) for x in file_names})

    crop_border = config['test_dataset']['wrapper']['args']['scale_ratio'] + 5
    dataset_name = os.path.basename(split_file).split('_')[0]
    # Empirical per-dataset scale limits; beyond them the border grows with
    # the excess scale (the rationale is not visible here — kept as-is).
    max_scale = {'UC': 5, 'AID': 12}
    if dataset_name in max_scale and args.scale_ratio > max_scale[dataset_name]:
        crop_border = int((args.scale_ratio - max_scale[dataset_name]) / 2 * 48)

    if args.save_fig or args.save_featmap:
        os.makedirs(args.save_path, exist_ok=True)

    res = eval_psnr(
        loader, class_names, model,
        data_norm=config.get('data_norm'),
        eval_type=config.get('eval_type'),
        crop_border=crop_border,
        verbose=True,
        save_fig=args.save_fig,
        save_featmap=args.save_featmap,
        scale_ratio=args.scale_ratio,
        save_path=args.save_path,
        cal_metrics=args.cal_metrics
    )

    if args.return_class_metrics:
        keys = sorted(res[0].keys())
        print('psnr')
        for k in keys:
            print(f'{k}: {res[0][k]:0.2f}')
        print('ssim')
        for k in keys:
            print(f'{k}: {res[1][k]:0.4f}')

    print(f'psnr: {res[0]["all"]:0.2f}')
    print(f'ssim: {res[1]["all"]:0.4f}')