| import os |
| import skimage |
| import argparse |
| import numpy as np |
| from tqdm import tqdm |
| from PIL import Image |
|
|
| import torch |
| import torch.nn.functional as F |
| import torchvision.transforms.functional as tf |
|
|
| from . import model |
|
|
|
|
def load_iHarmony4_subset(dataset_dir, mode):
    """Collect the file paths of one split of an iHarmony4 sub-dataset.

    The split listing is read from ``<dataset_dir>/<name>_<mode>.txt``, where
    ``<name>`` is the last path component of ``dataset_dir``. Every file in
    ``composite_images`` whose name appears in that listing yields one sample.
    The mask name is the composite name without its last ``_``-separated
    field (plus ``.png``); the real-image name is the composite name without
    its last two fields (plus ``.jpg``).

    Args:
        dataset_dir: Root directory of the sub-dataset. A trailing path
            separator is tolerated.
        mode: Split to load, either ``'train'`` or ``'test'``.

    Returns:
        A list of dicts with the keys ``'comp'``, ``'mask'`` and ``'real'``
        (file paths), sorted by composite file name so that the order is
        reproducible across file systems.

    Raises:
        SystemExit: If ``mode`` is not a supported split.
    """
    if mode not in ('train', 'test'):
        print('Invalid mode: {0} for the dataset: {1}'.format(mode, dataset_dir))
        # Exit with a non-zero status so the failure is visible to the shell.
        raise SystemExit(1)

    # normpath strips a trailing separator; a plain split('/')[-1] would give
    # an empty dataset name for "path/to/HCOCO/" and look for "_test.txt".
    dataset_name = os.path.basename(os.path.normpath(dataset_dir))
    list_path = os.path.join(dataset_dir, '{0}_{1}.txt'.format(dataset_name, mode))
    with open(list_path, 'r') as f:
        # A set gives O(1) membership tests while scanning the directory.
        sample_names = {line.strip() for line in f}

    comp_dir = os.path.join(dataset_dir, 'composite_images')
    mask_dir = os.path.join(dataset_dir, 'masks')
    real_dir = os.path.join(dataset_dir, 'real_images')

    samples = []
    for comp_name in sorted(os.listdir(comp_dir)):
        if comp_name not in sample_names:
            continue

        fields = comp_name.split('_')
        mask_name = '_'.join(fields[:-1]) + '.png'
        real_name = '_'.join(fields[:-2]) + '.jpg'

        samples.append({
            'comp': os.path.join(comp_dir, comp_name),
            'mask': os.path.join(mask_dir, mask_name),
            'real': os.path.join(real_dir, real_name),
        })

    return samples
|
|
|
|
def calc_metrics(pred, gt, mask):
    """Compute MSE, foreground MSE, PSNR and SSIM for a single image.

    ``pred`` and ``gt`` are multiplied by 255 and clamped to ``[0, 255]``
    before comparison, and all three tensors are converted to H x W x C
    numpy arrays.

    Args:
        pred: Predicted image tensor of shape ``(1, C, H, W)``.
        gt: Reference image tensor, processed exactly like ``pred``.
        mask: Foreground mask tensor with batch size 1; its first channel is
            summed to count foreground pixels (presumably a 0/1 mask of shape
            ``(1, 1, H, W)`` -- confirm against the caller).

    Returns:
        A tuple ``(mse, fmse, psnr, ssim)``. ``fmse`` is the MSE of the
        masked images rescaled by ``H * W / foreground_pixels``; an all-zero
        mask therefore makes that normalisation divide by zero.
    """
    import inspect

    n, _, h, w = pred.shape
    assert n == 1, 'calc_metrics expects a batch containing exactly one image'
    total_pixels = h * w
    fg_pixels = int(mask[0, 0].sum().item())

    def to_hwc(tensor):
        # (1, C, H, W) tensor -> (H, W, C) numpy array on the host.
        return tensor[0].permute(1, 2, 0).cpu().numpy()

    pred = to_hwc(torch.clamp(pred * 255, 0, 255))
    gt = to_hwc(torch.clamp(gt * 255, 0, 255))
    mask = to_hwc(mask)

    mse = skimage.metrics.mean_squared_error(pred, gt)
    fmse = skimage.metrics.mean_squared_error(pred * mask, gt * mask) * total_pixels / fg_pixels
    psnr = skimage.metrics.peak_signal_noise_ratio(pred, gt, data_range=pred.max() - pred.min())

    ssim_fn = skimage.metrics.structural_similarity
    # `multichannel` was superseded by `channel_axis`; use whichever keyword
    # the installed scikit-image release actually declares.
    if 'channel_axis' in inspect.signature(ssim_fn).parameters:
        ssim_kwargs = {'channel_axis': -1}
    else:
        ssim_kwargs = {'multichannel': True}
    # Older releases silently used the float dtype range (-1, 1), i.e. 2.0,
    # when data_range was omitted, while newer ones reject float input
    # without it. Passing that value explicitly keeps the scores identical
    # to the historic behaviour on every release.
    if np.issubdtype(pred.dtype, np.floating):
        ssim_kwargs['data_range'] = 2.0
    ssim = ssim_fn(pred, gt, **ssim_kwargs)

    return mse, fmse, psnr, ssim
|
|
|
|
if __name__ == '__main__':
    # Evaluation entry point: runs a pretrained Harmonizer on the test split
    # of the selected iHarmony4 subsets and prints MSE / fMSE / PSNR / SSIM
    # per dataset and over all samples.

    def load_image(path, mode):
        """Open ``path`` and return a converted copy, closing the file."""
        with Image.open(path) as handle:
            return handle.convert(mode)

    def to_batch(img):
        """Turn a PIL image into a (1, C, H, W) tensor on the run device."""
        return tf.to_tensor(img)[None, ...].to(device)

    # Locate the dataset root.
    DATASET_DIR = './dataset'
    if not os.path.exists(DATASET_DIR):
        print('Cannot find the dataset dir')
        # Non-zero status so the failure is visible to the shell.
        raise SystemExit(1)

    # Supported validation subsets.
    DATASETS = {
        'HCOCO': os.path.join(DATASET_DIR, 'harmonization/iHarmony4/HCOCO'),
        'HFlickr': os.path.join(DATASET_DIR, 'harmonization/iHarmony4/HFlickr'),
        'Hday2night': os.path.join(DATASET_DIR, 'harmonization/iHarmony4/Hday2night'),
    }

    # Command-line arguments (unknown arguments are ignored on purpose).
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained', type=str, default='./pretrained/harmonizer.pth', help='')
    parser.add_argument('--datasets', type=str, nargs='+', required=True, choices=DATASETS.keys(), help='')
    parser.add_argument('--metric-size', type=int, default=0, help='')
    args = parser.parse_known_args()[0]

    # A positive --metric-size means metrics are computed on resized images.
    metric_size = (args.metric_size, args.metric_size) if args.metric_size > 0 else None
    cuda = torch.cuda.is_available()
    device = 'cuda' if cuda else 'cpu'

    print('\n')
    print('Evaluation Harmonizer:')
    print(' - Pretrained Model: {0}'.format(args.pretrained))
    print(' - Validation Datasets: {0}'.format(args.datasets))
    print(' - Metric Calculation Size: {0}'.format(metric_size if metric_size is not None else 'original'))

    # Build the model. map_location lets a checkpoint that was saved from a
    # GPU be loaded on a CPU-only machine as well.
    harmonizer = model.Harmonizer().to(device)
    harmonizer.load_state_dict(torch.load(args.pretrained, map_location=device), strict=True)
    harmonizer.eval()

    # Resolve the sample lists of every requested dataset.
    datasets = {d: load_iHarmony4_subset(DATASETS[d], 'test') for d in args.datasets}

    metrics = {}
    for dkey, dvalue in datasets.items():
        print('\n')
        print('================================================================================')
        print('Validation Dataset: {0}'.format(dkey))
        print('--------------------------------------------------------------------------------')
        metric = {'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}
        sample_num = len(dvalue)
        pbar = tqdm(dvalue, total=sample_num, unit='sample')

        for i, sample in enumerate(pbar):
            comp = load_image(sample['comp'], 'RGB')
            mask = load_image(sample['mask'], '1')
            image = load_image(sample['real'], 'RGB')

            # The arguments are always predicted from the original-size
            # composite and mask.
            _comp, _mask = to_batch(comp), to_batch(mask)
            with torch.no_grad():
                arguments = harmonizer.predict_arguments(_comp, _mask)

            # Optionally switch to the metric resolution before restoring.
            if metric_size is not None:
                comp = tf.resize(comp, metric_size)
                mask = tf.resize(mask, metric_size)
                image = tf.resize(image, metric_size)
                _comp, _mask = to_batch(comp), to_batch(mask)
            _image = to_batch(image)

            with torch.no_grad():
                # The last element returned by restore_image is used as the
                # harmonized result.
                _harmonized = harmonizer.restore_image(_comp, _mask, arguments)[-1]

            mse, fmse, psnr, ssim = calc_metrics(_harmonized, _image, _mask)

            metric['MSE'] += mse
            metric['fMSE'] += fmse
            metric['PSNR'] += psnr
            metric['SSIM'] += ssim
            pbar.set_description('MSE: {0:.4f} fMSE: {1:.4f} PSNR: {2:.4f} SSIM: {3:.4f}'.format(
                metric['MSE']/(i+1), metric['fMSE']/(i+1), metric['PSNR']/(i+1), metric['SSIM']/(i+1)))

        # Guard the averages against a split that produced no samples.
        denom = max(sample_num, 1)
        print('--------------------------------------------------------------------------------')
        print('{0} - MSE: {1:.4f} fMSE: {2:.4f} PSNR: {3:.4f} SSIM: {4:.4f}'.format(
            dkey, metric['MSE']/denom, metric['fMSE']/denom, metric['PSNR']/denom, metric['SSIM']/denom))
        print('================================================================================')

        metrics[dkey] = metric

    # Sample-weighted averages over all evaluated datasets.
    sample_num = max(sum(len(dvalue) for dvalue in datasets.values()), 1)
    mse = sum(metric['MSE'] for metric in metrics.values()) / sample_num
    fmse = sum(metric['fMSE'] for metric in metrics.values()) / sample_num
    psnr = sum(metric['PSNR'] for metric in metrics.values()) / sample_num
    ssim = sum(metric['SSIM'] for metric in metrics.values()) / sample_num

    print('\n')
    print('================================================================================')
    print('All - MSE: {0:.4f} fMSE: {1:.4f} PSNR: {2:.4f} SSIM: {3:.4f}'.format(mse, fmse, psnr, ssim))
    print('================================================================================')
    print('\n')
|
|