radames's picture
add inversion
d9778ff
raw
history blame
No virus
3.39 kB
from argparse import ArgumentParser
import time
import numpy as np
import os
import json
import sys
from PIL import Image
import multiprocessing as mp
import math
import torch
import torchvision.transforms as trans
sys.path.append(".")
sys.path.append("..")
from models.mtcnn.mtcnn import MTCNN
from models.encoders.model_irse import IR_101
from configs.paths_config import model_paths
# Checkpoint path (looked up in the project's paths config) whose state dict
# extract_on_paths loads into the IR_101 network.
CIRCULAR_FACE_PATH = model_paths['circular_face']
def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of at most ``n`` items.

    Every slice has exactly ``n`` elements except possibly the last one,
    which holds whatever remains.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))
def extract_on_paths(file_paths):
    """Score each (result, ground-truth) image pair with a face-ID network.

    Intended to run inside a worker process: it builds its own IR_101 network
    (state dict read from CIRCULAR_FACE_PATH, moved to the GPU, eval mode) and
    its own MTCNN aligner, then processes every pair it was handed.

    Args:
        file_paths: sequence of ``(res_path, gt_path)`` pairs of image paths.

    Returns:
        dict mapping ``os.path.basename(gt_path)`` to the float dot product of
        the network outputs for the two aligned images. Pairs for which
        ``mtcnn.align`` returns ``None`` for either image are skipped (a
        message is printed) and do not appear in the dict.
    """
    facenet = IR_101(input_size=112)
    facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH))
    facenet.cuda()
    facenet.eval()
    mtcnn = MTCNN()
    id_transform = trans.Compose([
        trans.ToTensor(),
        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    pid = mp.current_process().name
    tot_count = len(file_paths)
    print('\t{} is starting to extract on {} images'.format(pid, tot_count))

    def _embed(image_path):
        """Return the network output for the aligned face, or None if alignment fails."""
        aligned, _ = mtcnn.align(Image.open(image_path))
        if aligned is None:
            print('{} skipping {}'.format(pid, image_path))
            return None
        return facenet(id_transform(aligned).unsqueeze(0).cuda())[0]

    scores_dict = {}
    # Pure inference: disable autograd so no graph is built or retained for
    # the forward passes (saves GPU memory and time; results are unchanged).
    with torch.no_grad():
        for count, (res_path, gt_path) in enumerate(file_paths, start=1):
            if count % 100 == 0:
                print('{} done with {}/{}'.format(pid, count, tot_count))

            input_id = _embed(res_path)
            if input_id is None:
                continue

            result_id = _embed(gt_path)
            if result_id is None:
                continue

            # NOTE(review): presumably a cosine similarity if the network
            # emits unit-norm embeddings -- not visible from this file.
            scores_dict[os.path.basename(gt_path)] = float(input_id.dot(result_id))

    return scores_dict
def parse_args():
    """Read the script's command-line options from ``sys.argv``.

    Options (all optional):
        --num_threads  int, default 4
        --data_path    str, default 'results'
        --gt_path      str, default 'gt_images'

    Returns:
        The populated argparse namespace.
    """
    cli = ArgumentParser(add_help=False)
    option_specs = (
        ('--num_threads', int, 4),
        ('--data_path', str, 'results'),
        ('--gt_path', str, 'gt_images'),
    )
    for flag, kind, fallback in option_specs:
        cli.add_argument(flag, type=kind, default=fallback)
    return cli.parse_args()
def run(args):
    """Score every image in ``args.data_path`` against its ground truth.

    Pairs each ``.jpg``/``.png`` file in ``args.data_path`` with a same-named
    ``.jpg`` file under ``args.gt_path``, splits the pairs across a pool of
    ``args.num_threads`` worker processes running ``extract_on_paths``, then
    prints the mean/std of the scores and writes ``stat_id.txt`` and
    ``scores_id.json`` into an ``inference_metrics`` directory next to
    ``args.data_path``.

    Args:
        args: namespace with ``data_path``, ``gt_path`` and ``num_threads``.

    Returns:
        None. If no images are found, a message is printed and nothing is
        written.
    """
    file_paths = []
    for f in os.listdir(args.data_path):
        if not f.endswith(('.jpg', '.png')):
            continue
        # Ground-truth images are looked up as .jpg. Swap only the file
        # extension, so a '.png' elsewhere in the path is left alone.
        gt_name = f[:-len('.png')] + '.jpg' if f.endswith('.png') else f
        file_paths.append([
            os.path.join(args.data_path, f),
            os.path.join(args.gt_path, gt_name),
        ])

    if not file_paths:
        # A chunk size of 0 would make range() raise an opaque ValueError.
        print('No .jpg/.png images found in {}; nothing to score'.format(args.data_path))
        return

    chunk_size = int(math.ceil(len(file_paths) / args.num_threads))
    file_chunks = list(chunks(file_paths, chunk_size))

    print('Running on {} paths\nHere we goooo'.format(len(file_paths)))
    # The context manager tears the worker processes down once map() returns.
    with mp.Pool(args.num_threads) as pool:
        tic = time.time()
        results = pool.map(extract_on_paths, file_chunks)

    scores_dict = {}
    for d in results:
        scores_dict.update(d)

    all_scores = list(scores_dict.values())
    mean = np.mean(all_scores)
    std = np.std(all_scores)
    result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std)
    print(result_str)

    out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
    os.makedirs(out_path, exist_ok=True)

    with open(os.path.join(out_path, 'stat_id.txt'), 'w') as f:
        f.write(result_str)
    with open(os.path.join(out_path, 'scores_id.json'), 'w') as f:
        json.dump(scores_dict, f)

    toc = time.time()
    print('Mischief managed in {}s'.format(toc - tic))
# Script entry point: parse the command-line options and run the evaluation
# only when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    args = parse_args()
    run(args)