"""Score identity similarity between result images and ground-truth images.

For every ``.jpg``/``.png`` file in ``--data_path`` the script looks up a
``.jpg`` file with the same stem in ``--gt_path``, aligns both images with
MTCNN, embeds them with an IR-101 network and stores the dot product of the
two embeddings.  The mean/std and the per-image scores are written to an
``inference_metrics`` directory next to ``--data_path``.
"""
from argparse import ArgumentParser
import time
import numpy as np
import os
import json
import sys
from PIL import Image
import multiprocessing as mp
import math
import torch
import torchvision.transforms as trans

# Make the project packages importable when the script is launched from the
# repository root or from one directory below it.
sys.path.append(".")
sys.path.append("..")

from models.mtcnn.mtcnn import MTCNN
from models.encoders.model_irse import IR_101
from configs.paths_config import model_paths

CIRCULAR_FACE_PATH = model_paths['circular_face']


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def _embed_face(image_path, mtcnn, facenet, id_transform):
    """Return the identity embedding for one image, or None if alignment fails.

    The image is aligned with ``mtcnn.align``; when that yields ``None`` as
    the aligned image, ``None`` is returned so the caller can skip the pair.
    """
    aligned, _ = mtcnn.align(Image.open(image_path))
    if aligned is None:
        return None
    batch = id_transform(aligned).unsqueeze(0).cuda()
    return facenet(batch)[0]


def extract_on_paths(file_paths):
    """Compute identity scores for a list of (result, ground-truth) path pairs.

    Runs inside a worker process: it builds its own IR-101 network (weights
    loaded from ``CIRCULAR_FACE_PATH`` and moved to the GPU) and its own MTCNN
    aligner, then scores each pair.

    Args:
        file_paths: sequence of ``[result_image_path, gt_image_path]`` pairs.

    Returns:
        dict mapping the ground-truth file's basename to the dot product of
        the result embedding and the ground-truth embedding (a float).
        Pairs where either image could not be aligned are left out.
        NOTE(review): presumably the embeddings are unit-normalised, which
        would make this a cosine similarity -- confirm against IR_101.
    """
    facenet = IR_101(input_size=112)
    facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH))
    facenet.cuda()
    facenet.eval()
    mtcnn = MTCNN()
    id_transform = trans.Compose([
        trans.ToTensor(),
        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    pid = mp.current_process().name
    print('\t{} is starting to extract on {} images'.format(pid, len(file_paths)))
    tot_count = len(file_paths)

    scores_dict = {}
    # Pure inference: disable autograd so no graph is built per image.
    with torch.no_grad():
        for count, (res_path, gt_path) in enumerate(file_paths, start=1):
            if count % 100 == 0:
                print('{} done with {}/{}'.format(pid, count, tot_count))

            res_id = _embed_face(res_path, mtcnn, facenet, id_transform)
            if res_id is None:
                print('{} skipping {}'.format(pid, res_path))
                continue

            gt_id = _embed_face(gt_path, mtcnn, facenet, id_transform)
            if gt_id is None:
                print('{} skipping {}'.format(pid, gt_path))
                continue

            scores_dict[os.path.basename(gt_path)] = float(res_id.dot(gt_id))

    return scores_dict


def parse_args():
    """Parse the command-line options.

    Returns:
        Namespace with ``num_threads`` (int, default 4), ``data_path``
        (str, default ``'results'``) and ``gt_path`` (str, default
        ``'gt_images'``).
    """
    parser = ArgumentParser(add_help=False)
    parser.add_argument('--num_threads', type=int, default=4)
    parser.add_argument('--data_path', type=str, default='results')
    parser.add_argument('--gt_path', type=str, default='gt_images')
    args = parser.parse_args()
    return args


def run(args):
    """Score every image in ``args.data_path`` and write the metrics to disk.

    The image pairs are split into ``args.num_threads`` roughly equal chunks
    that are scored in parallel by a process pool.  Two files are written to
    ``<dirname(args.data_path)>/inference_metrics``: ``stat_id.txt`` (mean and
    standard deviation) and ``scores_id.json`` (per-image scores).

    Args:
        args: namespace providing ``data_path``, ``gt_path`` and
            ``num_threads`` (see ``parse_args``).
    """
    file_paths = []
    for f in os.listdir(args.data_path):
        if not f.endswith(('.jpg', '.png')):
            continue
        image_path = os.path.join(args.data_path, f)
        # Ground-truth images are looked up as .jpg.  Swap only the file's
        # own suffix so a '.png' inside a directory name is left untouched.
        gt_name = f[:-len('.png')] + '.jpg' if f.endswith('.png') else f
        file_paths.append([image_path, os.path.join(args.gt_path, gt_name)])

    # Clamp to 1: with no images the ceil() is 0 and range() rejects a
    # zero step inside chunks().
    chunk_size = max(1, int(math.ceil(len(file_paths) / args.num_threads)))
    file_chunks = list(chunks(file_paths, chunk_size))

    # The context manager tears the worker processes down once map() is done.
    with mp.Pool(args.num_threads) as pool:
        print('Running on {} paths\nHere we goooo'.format(len(file_paths)))
        tic = time.time()
        results = pool.map(extract_on_paths, file_chunks)

    scores_dict = {}
    for d in results:
        scores_dict.update(d)

    all_scores = list(scores_dict.values())
    if all_scores:
        mean = np.mean(all_scores)
        std = np.std(all_scores)
    else:
        # Nothing was scored; report NaN without triggering numpy's
        # empty-slice warnings.
        mean = std = float('nan')
    result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std)
    print(result_str)

    out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
    os.makedirs(out_path, exist_ok=True)

    with open(os.path.join(out_path, 'stat_id.txt'), 'w') as f:
        f.write(result_str)
    with open(os.path.join(out_path, 'scores_id.json'), 'w') as f:
        json.dump(scores_dict, f)

    toc = time.time()
    print('Mischief managed in {}s'.format(toc - tic))


if __name__ == '__main__':
    args = parse_args()
    run(args)