# -*- coding: utf-8 -*-
'''
This script computes the recall scores given the ground-truth annotations and predictions.
'''
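
# Expected input format (one json object per line, jsonl):
#   ground truth: {"text_id": <int>, "image_ids": [<int>, ...]}
#   prediction:   {"text_id": <int>, "image_ids": [<int>, ... NUM_K items]}
# Usage sketch (the script name is illustrative; the file names below come
# from the comments in __main__):
#   python evaluate.py test_queries_answers.jsonl example_pred.jsonl out.json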
import json
import os
import sys

NUM_K = 10

def read_submission(submit_path, reference, k=NUM_K):
    '''Read and validate the prediction file; return {text_id: [image_ids]}.'''
    # check whether the path of the submitted file exists
    if not os.path.exists(submit_path):
        raise Exception("The submission file is not found!")
    submission_dict = {}
    ref_qids = set(reference.keys())
    with open(submit_path, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            try:
                pred_obj = json.loads(line)
            except ValueError:
                raise Exception('Cannot parse this line into a json object: {}'.format(line))
            if "text_id" not in pred_obj:
                raise Exception('There exists one line not containing text_id: {}'.format(line))
            qid = pred_obj["text_id"]
            if not isinstance(qid, int):
                raise Exception('Found an invalid text_id {}, it should be an integer (not string), please check your schema'.format(qid))
            if "image_ids" not in pred_obj:
                raise Exception('There exists one line not containing the predicted image_ids: {}'.format(line))
            image_ids = pred_obj["image_ids"]
            if not isinstance(image_ids, list):
                raise Exception('The image_ids field of text_id {} is not a list, please check your schema'.format(qid))
            # check whether there are exactly k predicted products for each text
            if len(image_ids) != k:
                raise Exception('Text_id {} has the wrong number of predicted image_ids! Require {}, but {} found.'.format(qid, k, len(image_ids)))
            # check whether there exists an invalid prediction for any text
            for rank, image_id in enumerate(image_ids):
                if not isinstance(image_id, int):
                    raise Exception('Text_id {} has an invalid predicted image_id {} at rank {}, it should be an integer (not string), please check your schema'.format(qid, image_id, rank + 1))
            # check whether there are duplicate predicted products for a single text
            if len(set(image_ids)) != k:
                raise Exception('Text_id {} has duplicate products in your prediction. Please check again!'.format(qid))
            submission_dict[qid] = image_ids  # here we save the list of product ids
    # check if any text is missing in the submission
    pred_qids = set(submission_dict.keys())
    nopred_qids = ref_qids - pred_qids
    if len(nopred_qids) != 0:
        raise Exception('The following text_ids have no prediction in your submission, please check again: {}'.format(", ".join([str(idx) for idx in nopred_qids])))
    return submission_dict
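
# A valid submission line, for reference (the ids are made up):
#   {"text_id": 3, "image_ids": [71, 22, 58, 10, 41, 32, 45, 92, 16, 8]}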

def dump_2_json(info, path):
    with open(path, 'w', encoding="utf-8") as output_json_file:
        json.dump(info, output_json_file)

def report_error_msg(detail, showMsg, out_p):
    error_dict = dict()
    error_dict['errorDetail'] = detail
    error_dict['errorMsg'] = showMsg
    error_dict['score'] = 0
    error_dict['scoreJson'] = {}
    error_dict['success'] = False
    dump_2_json(error_dict, out_p)

def report_score(r1, r5, r10, out_p):
    result = dict()
    result['success'] = True
    mean_recall = (r1 + r5 + r10) / 3.0
    result['score'] = mean_recall * 100
    result['scoreJson'] = {'score': mean_recall * 100, 'mean_recall': mean_recall * 100,
                           'r1': r1 * 100, 'r5': r5 * 100, 'r10': r10 * 100}
    dump_2_json(result, out_p)
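
# Shape of the dumped score json (the values are illustrative):
#   {"success": true, "score": 61.0,
#    "scoreJson": {"score": 61.0, "mean_recall": 61.0,
#                  "r1": 45.0, "r5": 65.0, "r10": 73.0}}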

def read_reference(path):
    reference = dict()
    with open(path, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            obj = json.loads(line)
            reference[obj['text_id']] = obj['image_ids']
    return reference

def compute_score(golden_file, predict_file):
    # read ground-truth
    reference = read_reference(golden_file)
    # read predictions
    predictions = read_submission(predict_file, reference, NUM_K)
    # compute the score for each text
    r1_stat, r5_stat, r10_stat = 0, 0, 0
    for qid in reference.keys():
        ground_truth_ids = set(reference[qid])
        top10_pred_ids = predictions[qid]
        if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
            r1_stat += 1
        if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
            r5_stat += 1
        if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
            r10_stat += 1
    # the higher the score, the better
    r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
    mean_recall = (r1 + r5 + r10) / 3.0
    result = [mean_recall, r1, r5, r10]
    result = [score * 100 for score in result]
    return result
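
# Minimal programmatic use, assuming the two jsonl files exist:
#   mean_recall, r1, r5, r10 = compute_score("test_queries_answers.jsonl",
#                                            "example_pred.jsonl")
#   # all four returned values are percentages in [0, 100]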

if __name__ == "__main__":
    # the path of the answer json file (e.g. test_queries_answers.jsonl)
    standard_path = sys.argv[1]
    # the path of the prediction file (e.g. example_pred.jsonl)
    submit_path = sys.argv[2]
    # the score will be dumped into this output json file
    out_path = sys.argv[3]
    print("Read standard from %s" % standard_path)
    print("Read user submit file from %s" % submit_path)
    try:
        # read ground-truth
        reference = read_reference(standard_path)
        # read predictions
        predictions = read_submission(submit_path, reference, NUM_K)
        # compute the score for each text
        r1_stat, r5_stat, r10_stat = 0, 0, 0
        for qid in reference.keys():
            ground_truth_ids = set(reference[qid])
            top10_pred_ids = predictions[qid]
            if any([idx in top10_pred_ids[:1] for idx in ground_truth_ids]):
                r1_stat += 1
            if any([idx in top10_pred_ids[:5] for idx in ground_truth_ids]):
                r5_stat += 1
            if any([idx in top10_pred_ids[:10] for idx in ground_truth_ids]):
                r10_stat += 1
        # the higher the score, the better
        r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
        report_score(r1, r5, r10, out_path)
        print("The evaluation finished successfully.")
    except Exception as e:
        report_error_msg(str(e), str(e), out_path)
        print("The evaluation failed: {}".format(e))