Spaces:
Running
Running
import json
import os
from collections import OrderedDict

import pkg_resources
# Paths
# Location of the 'data/annotations' resource of the installed `spiga` package.
# Annotation files are looked up below it as '<database>/<data_type>.json'.
data_path = pkg_resources.resource_filename('spiga', 'data/annotations')
def main():
    """Command-line entry point of the benchmark evaluator.

    Parses the arguments below, then builds one evaluator per prediction file
    (through ``get_evaluator``) and runs its ``metrics()``:

    * ``pred_file`` (positional, one or more): prediction json files.
    * ``--eval``: one or more of ``lnd`` / ``pose``; defaults to ``['lnd']``.
    * ``-s`` / ``--save``: flag forwarded as the ``save`` argument.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Benchmark alignments evaluator')
    parser.add_argument(
        'pred_file',
        nargs='+',
        type=str,
        help='Absolute path to the prediction json file (Multi file)',
    )
    parser.add_argument(
        '--eval',
        nargs='+',
        type=str,
        default=['lnd'],
        choices=['lnd', 'pose'],
        help='Evaluation modes',
    )
    parser.add_argument('-s', '--save', action='store_true', help='Save results')
    options = parser.parse_args()

    # Every prediction file is evaluated with the same modes and save flag.
    eval_modes, save_results = options.eval, options.save
    for prediction_path in options.pred_file:
        get_evaluator(prediction_path, eval_modes, save_results).metrics()
class Evaluator:
    """Evaluate a prediction file against the matching annotation file.

    The database and data type are read from the prediction file name: the
    part before the first ``.`` is split on ``_``; the last piece is the data
    type and the one before it the database (``results_wflw_test.json`` ->
    database ``wflw``, data type ``test``). Annotations are then loaded from
    ``<data_path>/<database>/<data_type>.json``.

    Each object in ``evals`` is expected to provide a ``name`` attribute and
    the methods ``compute_error(anns, pred, database_ref, select_ids)``,
    ``get_pimg_err(error_pimg)`` and ``metrics()``.

    Args:
        data_file: Path to the prediction json file.
        evals: Iterable of evaluator objects (see above).
        save: When true, ``metrics()`` also writes a text report named
            ``metrics_<database>_<data_type>.txt`` next to ``data_file``.
        process_err: When true, errors are computed during construction.

    Raises:
        IndexError: If the file name holds fewer than two ``_``-separated
            pieces before its first ``.``.
    """

    def __init__(self, data_file, evals=(), save=True, process_err=True):
        # Inputs
        self.data_file = data_file
        self.evals = evals
        self.save = save

        # Paths
        # Use the platform-aware basename and take the directory as the exact
        # prefix in front of it (trailing separator kept, '' for a bare file
        # name). Splitting the path on the file name, as done previously,
        # breaks when that name also appears earlier in the path.
        data_name = os.path.basename(data_file)
        self.data_dir = data_file[:len(data_file) - len(data_name)]

        # Information from name
        name_parts = data_name.split('.')[0].split('_')
        self.data_type = name_parts[-1]
        self.database = name_parts[-2]

        # Load predictions and annotations
        anns_file = os.path.join(data_path, self.database, '%s.json' % self.data_type)
        self.anns = self.load_files(anns_file)
        self.pred = self.load_files(data_file)

        # Compute errors
        self.error = OrderedDict()
        self.error_pimg = OrderedDict()
        self.metrics_log = OrderedDict()
        if process_err:
            self.compute_error(self.anns, self.pred)

    def compute_error(self, anns, pred, select_ids=None):
        """Run every evaluator's ``compute_error`` and collect the results.

        Results are stored in ``self.error`` keyed by evaluator name, and
        ``self.error_pimg`` is replaced by what each evaluator's
        ``get_pimg_err`` returns for it.

        Returns:
            ``self.error``.
        """
        database_ref = [self.database, self.data_type]
        # Named `evaluator` so the builtin `eval` is not shadowed.
        for evaluator in self.evals:
            self.error[evaluator.name] = evaluator.compute_error(anns, pred, database_ref, select_ids)
            self.error_pimg = evaluator.get_pimg_err(self.error_pimg)
        return self.error

    def metrics(self):
        """Collect ``metrics()`` of every evaluator, optionally saving a report.

        Returns:
            ``self.metrics_log``, keyed by evaluator name.
        """
        for evaluator in self.evals:
            self.metrics_log[evaluator.name] = evaluator.metrics()

        if self.save:
            # os.path.join keeps the report beside the predictions; with an
            # empty data_dir, plain '/' concatenation pointed at the
            # filesystem root instead.
            file_name = os.path.join(
                self.data_dir,
                'metrics_%s_%s.txt' % (self.database, self.data_type),
            )
            with open(file_name, 'w') as file:
                file.write(str(self))

        return self.metrics_log

    def load_files(self, input_file):
        """Return the parsed content of the json file at ``input_file``."""
        # JSON is read as UTF-8 regardless of the platform's locale encoding.
        with open(input_file, encoding='utf-8') as jsonfile:
            return json.load(jsonfile)

    def _dict2text(self, name, dictionary, num_tab=1):
        """Render ``dictionary`` as a tab-indented ``name { ... }`` text block.

        Nested dictionaries are rendered recursively one tab level deeper;
        any other value is written as ``key: value`` on its own line.
        """
        prev_tabs = '\t' * num_tab
        parts = ['%s {\n' % name]
        for key, value in dictionary.items():
            # OrderedDict is a dict subclass, so one isinstance check suffices.
            if isinstance(value, dict):
                parts.append(prev_tabs + self._dict2text(key, value, num_tab=num_tab + 1))
            else:
                parts.append('{}{}: {}\n'.format(prev_tabs, key, value))
        parts.append(prev_tabs + '}\n')
        return ''.join(parts)

    def __str__(self):
        """Return the collected metrics log as text, under a 'Metrics' header."""
        return self._dict2text('Metrics', self.metrics_log)
def get_evaluator(pred_file, evaluate=('lnd', 'pose'), save=False, process_err=True):
    """Build an ``Evaluator`` for ``pred_file`` with the requested metric modes.

    Args:
        pred_file: Path to the prediction json file.
        evaluate: Container of mode names; ``"lnd"`` adds ``MetricsLandmarks``
            and ``"pose"`` adds ``MetricsHeadpose`` (always in that order).
        save: Forwarded to ``Evaluator``.
        process_err: Forwarded to ``Evaluator``.

    Returns:
        The constructed ``Evaluator``.
    """
    # The metric modules are imported lazily, only for the modes requested.
    def _build_landmarks():
        import spiga.eval.benchmark.metrics.landmarks as mlnd
        return mlnd.MetricsLandmarks()

    def _build_headpose():
        import spiga.eval.benchmark.metrics.pose as mpose
        return mpose.MetricsHeadpose()

    builders = (("lnd", _build_landmarks), ("pose", _build_headpose))
    selected = [build() for mode, build in builders if mode in evaluate]
    return Evaluator(pred_file, evals=selected, save=save, process_err=process_err)
# Run the command-line evaluator only when this file is executed as a script.
if __name__ == '__main__':
    main()