Spaces:
Sleeping
Sleeping
from argparse import ArgumentParser | |
import os | |
import json | |
import sys | |
from tqdm import tqdm | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader | |
import torchvision.transforms as transforms | |
sys.path.append(".") | |
sys.path.append("..") | |
from criteria.lpips.lpips import LPIPS | |
from datasets.gt_res_dataset import GTResDataset | |
def parse_args():
    """Parse the command-line options for the loss-evaluation script.

    Options (all optional):
        --mode        metric to compute: 'lpips' (default) or 'l2'
        --data_path   directory handed to GTResDataset as ``root_path``
                      (default 'results')
        --gt_path     directory handed to GTResDataset as ``gt_dir``
                      (default 'gt_images')
        --workers     DataLoader ``num_workers`` (default 4)
        --batch_size  DataLoader batch size (default 4)
        --is_cars     store_true flag; when set, ``run`` resizes to (192, 256)

    The parser is built with add_help=False, so -h/--help is not registered.

    Returns:
        argparse.Namespace with one attribute per option above.
    """
    parser_spec = [
        ('--mode', {'type': str, 'default': 'lpips', 'choices': ['lpips', 'l2']}),
        ('--data_path', {'type': str, 'default': 'results'}),
        ('--gt_path', {'type': str, 'default': 'gt_images'}),
        ('--workers', {'type': int, 'default': 4}),
        ('--batch_size', {'type': int, 'default': 4}),
        ('--is_cars', {'action': 'store_true'}),
    ]
    cli = ArgumentParser(add_help=False)
    for flag, options in parser_spec:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def _build_dataloader(args):
    """Build the result/ground-truth dataset and an ordered DataLoader over it.

    Returns:
        (dataset, dataloader). The loader uses shuffle=False so samples arrive
        in dataset index order (the scoring loop relies on this to look up
        ``dataset.pairs``), and drop_last=False so a final partial batch is
        still evaluated.
    """
    # --is_cars switches to a (192, 256) resize; otherwise (256, 256).
    resize_dims = (192, 256) if args.is_cars else (256, 256)
    transform = transforms.Compose([
        transforms.Resize(resize_dims),
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5: maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    print('Loading dataset')
    dataset = GTResDataset(root_path=args.data_path,
                           gt_dir=args.gt_path,
                           transform=transform)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=int(args.workers),
                            # Keep the final partial batch; dropping it would
                            # silently skip len(dataset) % batch_size images.
                            drop_last=False)
    return dataset, dataloader


def _build_loss(mode):
    """Return the loss module for ``mode`` ('lpips' or 'l2'), moved to CUDA.

    Raises:
        ValueError: if ``mode`` is not a supported metric.
    """
    if mode == 'lpips':
        loss_func = LPIPS(net_type='alex')
    elif mode == 'l2':
        loss_func = torch.nn.MSELoss()
    else:
        raise ValueError('Not a valid mode!')
    loss_func.cuda()
    return loss_func


def _score_pairs(dataset, dataloader, loss_func):
    """Score every (result, ground-truth) pair one image at a time.

    Returns:
        (all_scores, scores_dict): ``all_scores`` holds one float per evaluated
        pair in loader order; ``scores_dict`` maps the basename of
        ``dataset.pairs[k][0]`` to the k-th score (a later pair with the same
        basename overwrites an earlier one).
    """
    all_scores = []
    scores_dict = {}
    global_i = 0  # running dataset index; valid because the loader is unshuffled
    # Scores are only read, never back-propagated, so skip autograd bookkeeping.
    with torch.no_grad():
        for result_batch, gt_batch in tqdm(dataloader):
            # Transfer each batch once instead of once per image.
            result_batch = result_batch.cuda()
            gt_batch = gt_batch.cuda()
            # Use the real batch length: the last batch may be shorter than
            # the configured batch size.
            for i in range(len(result_batch)):
                # i:i + 1 keeps the batch dimension, so each call scores a
                # single image pair.
                loss = float(loss_func(result_batch[i:i + 1], gt_batch[i:i + 1]))
                all_scores.append(loss)
                # NOTE(review): assumes pairs[k][0] is the result-image path of
                # sample k -- confirm against GTResDataset.
                im_path = dataset.pairs[global_i][0]
                scores_dict[os.path.basename(im_path)] = loss
                global_i += 1
    return all_scores, scores_dict


def _write_results(args, all_scores, scores_dict):
    """Print mean/std of ``all_scores`` and save the summary and per-image scores.

    Writes ``stat_<mode>.txt`` (the summary line) and ``scores_<mode>.json``
    (``scores_dict``) into ``inference_metrics`` next to ``args.data_path``.
    """
    overwritten = len(all_scores) - len(scores_dict)
    if overwritten:
        print('Warning: {} score(s) were overwritten in the JSON output because '
              'their images share a basename with a later pair.'.format(overwritten))
    # Statistics cover every evaluated pair, including basename collisions.
    mean = np.mean(all_scores)
    std = np.std(all_scores)
    result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std)
    print('Finished with ', args.data_path)
    print(result_str)
    # normpath strips a trailing separator so "exp/results/" and "exp/results"
    # both resolve to the sibling directory "exp/inference_metrics".
    parent_dir = os.path.dirname(os.path.normpath(args.data_path))
    out_path = os.path.join(parent_dir, 'inference_metrics')
    os.makedirs(out_path, exist_ok=True)
    with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f:
        f.write(result_str)
    with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f:
        json.dump(scores_dict, f)


def run(args):
    """Compute per-image LPIPS or L2 losses between result and ground-truth images.

    Args:
        args: namespace as returned by ``parse_args`` (reads mode, data_path,
            gt_path, workers, batch_size, is_cars).

    The mean and population std over every evaluated pair are printed and
    written to ``<parent of data_path>/inference_metrics/stat_<mode>.txt``.
    The per-image scores are written to ``scores_<mode>.json`` in the same
    directory. Requires a CUDA device, because the loss module and the images
    are moved with ``.cuda()``.

    Raises:
        ValueError: if ``args.mode`` is neither 'lpips' nor 'l2'.
        RuntimeError: if the dataset yields no image pairs.
    """
    dataset, dataloader = _build_dataloader(args)
    loss_func = _build_loss(args.mode)
    all_scores, scores_dict = _score_pairs(dataset, dataloader, loss_func)
    if not all_scores:
        # Without this guard np.mean/np.std of [] would write "nan+-nan".
        raise RuntimeError('No image pairs were evaluated for {}'.format(args.data_path))
    _write_results(args, all_scores, scores_dict)
if __name__ == '__main__':
    # Script entry point: read the CLI options, then run the evaluation.
    args = parse_args()
    run(args=args)