Spaces:
Runtime error
Runtime error
File size: 9,585 Bytes
d6585f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
"""
Official evaluation script for the MS MARCO Document Ranking task.
Authors: Daniel Campos, Rutger van Haasteren, Jimmy Lin
"""
import argparse
import re
import os
from collections import Counter
# Per-query limit on ranked documents: validate_candidate_has_enough_ranking
# prints a warning for any query whose candidate list is longer than this.
MaxMRRRank = 100
def autoopen(filename, mode="rt"):
    """Open ``filename`` like ``open()``, with transparent (de)compression.

    Paths ending in ``.gz`` are opened through ``gzip.open`` and paths ending
    in ``.bz2`` through ``bz2.open``; every other path goes to the builtin
    ``open``.

    Args:
        filename (str): path of the file to open.
        mode (str): open mode. If it names neither text (``t``) nor binary
            (``b``), ``t`` is appended so that all three openers agree on
            text mode.

    Returns:
        The file object produced by the selected opener.
    """
    names_text_or_binary = 't' in mode or 'b' in mode
    if not names_text_or_binary:
        mode += 't'

    # The compression modules are imported lazily, only when actually needed.
    if filename.endswith(".bz2"):
        import bz2
        return bz2.open(filename, mode)
    if filename.endswith(".gz"):
        import gzip
        return gzip.open(filename, mode)
    return open(filename, mode)
def load_reference_from_stream(f):
    """Load relevance judgments from an iterable of text lines.

    Each non-blank line is stripped and split on single whitespace
    characters. Field 0 is the query id (parsed as ``int``) and field 2 is a
    relevant document id (kept as the string found in the file). Blank lines
    are skipped.

    Args:
        f (iterable of str): stream (or any iterable) yielding lines.

    Returns:
        dict: mapping from query id (int) to the list of relevant document
        ids (str) in the order they appear in the stream.

    Raises:
        IOError: if a line has fewer than three fields or its first field is
            not an integer. The message quotes the offending line.
    """
    qids_to_relevant_documentids = {}
    for line in f:
        stripped = line.strip()
        if not stripped:
            # A blank (e.g. trailing) line carries no judgment.
            continue
        # Raw string: '\s' in a normal literal is an invalid escape sequence.
        fields = re.split(r'\s', stripped)
        try:
            qid = int(fields[0])
            docid = fields[2]
        except (ValueError, IndexError) as err:
            raise IOError('\"%s\" is not valid format' % stripped) from err
        qids_to_relevant_documentids.setdefault(qid, []).append(docid)
    return qids_to_relevant_documentids
def load_reference(path_to_reference):
    """Read the relevance judgments stored in a file.

    The file is opened in read mode through ``autoopen`` and handed to
    ``load_reference_from_stream``.

    Args:
        path_to_reference (str): path of the judgments file.

    Returns:
        dict: whatever ``load_reference_from_stream`` builds from the file's
        lines (a mapping keyed by integer query id).
    """
    with autoopen(path_to_reference, 'r') as reference_stream:
        return load_reference_from_stream(reference_stream)
def validate_candidate_has_enough_ranking(qid_to_ranked_candidate_documents):
    """Print a warning for every query with more than ``MaxMRRRank`` candidates.

    This is advisory only: nothing is raised, returned or modified.

    Args:
        qid_to_ranked_candidate_documents (dict): mapping from query id to
            the collection of candidates ranked for that query.
    """
    oversized_qids = (
        qid
        for qid, ranked in qid_to_ranked_candidate_documents.items()
        if len(ranked) > MaxMRRRank
    )
    for qid in oversized_qids:
        print('Too many documents ranked. Please Provide top 100 documents for qid:{}'.format(qid))
def load_candidate_from_stream(f):
    """Load a ranked run from an iterable of text lines.

    Each non-blank line is stripped and split on tab characters. Field 0 is
    the query id (parsed as ``int``), field 1 the document id (kept as a
    string) and field 2 the rank (parsed as ``int``). Blank lines are
    skipped.

    After loading, ``validate_candidate_has_enough_ranking`` is called (it
    only prints warnings) and a progress message is printed.

    Args:
        f (iterable of str): stream (or any iterable) yielding lines.

    Returns:
        dict: mapping from query id (int) to a list of ``(document id, rank)``
        tuples sorted by ascending rank, ties broken by document id.

    Raises:
        IOError: if a line has fewer than three tab-separated fields, or its
            query id or rank is not an integer. The message quotes the
            offending line.
    """
    qid_to_ranked_candidate_documents = {}
    for line in f:
        stripped = line.strip()
        if not stripped:
            # A blank (e.g. trailing) line carries no ranking entry.
            continue
        fields = stripped.split('\t')
        try:
            qid = int(fields[0])
            did = fields[1]
            rank = int(fields[2])
        except (ValueError, IndexError) as err:
            raise IOError('\"%s\" is not valid format' % stripped) from err
        qid_to_ranked_candidate_documents.setdefault(qid, []).append((did, rank))

    validate_candidate_has_enough_ranking(qid_to_ranked_candidate_documents)
    print('Quantity of Documents ranked for each query is as expected. Evaluating')
    return {
        qid: sorted(ranked, key=lambda entry: (entry[1], entry[0]))
        for qid, ranked in qid_to_ranked_candidate_documents.items()
    }
def load_candidate(path_to_candidate):
    """Read a ranked run stored in a file.

    The file is opened in read mode through ``autoopen`` and handed to
    ``load_candidate_from_stream``.

    Args:
        path_to_candidate (str): path of the run file.

    Returns:
        dict: whatever ``load_candidate_from_stream`` builds from the file's
        lines (a mapping keyed by integer query id).
    """
    with autoopen(path_to_candidate, 'r') as candidate_stream:
        return load_candidate_from_stream(candidate_stream)
def quality_checks_qids(qids_to_relevant_documentids, qids_to_ranked_candidate_documents):
    """Check that no query ranks the same document more than once.

    Candidate entries may be ``(document id, rank)`` tuples or bare document
    ids. Duplicates are detected on the document id alone, so a document
    listed at two different ranks is reported. A document id of ``0`` is
    ignored.

    Args:
        qids_to_relevant_documentids (dict): reference judgments; accepted
            for interface compatibility but not inspected by this check.
        qids_to_ranked_candidate_documents (dict): mapping from query id to
            the list of ranked candidates for that query.

    Returns:
        (bool, str): ``(True, '')`` when every query is clean; otherwise
        ``False`` and a message naming the last offending query visited and
        the first duplicated document id in its list.
    """
    message = ''
    allowed = True
    for qid, ranked in qids_to_ranked_candidate_documents.items():
        # Compare document ids only; the rank must not make two entries for
        # the same document look distinct.
        doc_ids = [entry[0] if isinstance(entry, tuple) else entry for entry in ranked]
        duplicates = [
            doc_id
            for doc_id, count in Counter(doc_ids).items()
            if count > 1 and doc_id != 0
        ]
        if duplicates:
            message = "Cannot rank a document multiple times for a single query. QID={qid}, PID={pid}".format(
                qid=qid, pid=duplicates[0])
            allowed = False
    return allowed, message
def compute_metrics(qids_to_relevant_documentids, qids_to_ranked_candidate_documents, exclude_qids):
    """Compute the mean reciprocal rank of a run against reference judgments.

    A query is scored when it appears in both mappings and is not in
    ``exclude_qids``. Its contribution is ``1 / position`` of the first
    candidate whose document id is among the query's relevant ids, where the
    position is the 1-based index in the candidate list (the stored rank
    value is not consulted); it contributes nothing when no candidate
    matches. The whole candidate list is scanned; no rank cutoff is applied
    here.

    The sum is divided by the total number of queries in
    ``qids_to_relevant_documentids``, including excluded or unranked ones.

    Args:
        qids_to_relevant_documentids (dict): query id -> relevant document ids.
        qids_to_ranked_candidate_documents (dict): query id -> ordered list of
            entries whose first element is the document id.
        exclude_qids (set): query ids to leave out of the scoring.

    Returns:
        dict: ``{'MRR @100': <score>, 'QueriesRanked': <count>}`` where the
        count is the number of candidate query ids not in ``exclude_qids``.

    Raises:
        IOError: if no query qualifies for scoring.
    """
    reciprocal_rank_sum = 0
    scored_queries = 0
    for qid, ranked in qids_to_ranked_candidate_documents.items():
        if qid not in qids_to_relevant_documentids or qid in exclude_qids:
            continue
        scored_queries += 1
        relevant = qids_to_relevant_documentids[qid]
        # 1-based position of the first relevant candidate, if there is one.
        first_hit = next(
            (position for position, entry in enumerate(ranked, start=1) if entry[0] in relevant),
            None,
        )
        if first_hit is not None:
            reciprocal_rank_sum += 1 / first_hit

    if scored_queries == 0:
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")

    return {
        'MRR @100': reciprocal_rank_sum / len(qids_to_relevant_documentids),
        'QueriesRanked': len(set(qids_to_ranked_candidate_documents) - exclude_qids),
    }
def compute_metrics_from_files(path_to_reference, path_to_candidate, exclude_qids, perform_checks=True):
    """Load a reference file and a run file, then score the run.

    Args:
        path_to_reference (str): path of the judgments file, read with
            ``load_reference`` (whitespace-separated fields; the first is
            the query id and the third a relevant document id).
        path_to_candidate (str): path of the run file, read with
            ``load_candidate`` (tab-separated fields: query id, document id,
            rank).
        exclude_qids (set): query ids to leave out of the scoring.
        perform_checks (bool): when true, run ``quality_checks_qids`` and
            print its message if there is one. The check is advisory:
            scoring proceeds regardless of its outcome.

    Returns:
        dict: the metrics dictionary produced by ``compute_metrics``.
    """
    reference = load_reference(path_to_reference)
    run = load_candidate(path_to_candidate)
    if perform_checks:
        _, problem = quality_checks_qids(reference, run)
        if problem:
            print(problem)
    return compute_metrics(reference, run, exclude_qids)
def load_exclude(path_to_exclude_folder):
    """Collect the query ids listed in every file of a folder.

    Each regular file directly inside the folder is opened through
    ``autoopen``. Its first line is skipped as a header, and the first
    tab-separated field of every remaining line is parsed as an integer
    query id. Sub-directories are ignored. The number of ids collected is
    printed before returning.

    Args:
        path_to_exclude_folder (str): folder holding the exclude files.

    Returns:
        set: all query ids (int) found across the files.
    """
    excluded = set()
    for entry in os.listdir(path_to_exclude_folder):
        entry_path = os.path.join(path_to_exclude_folder, entry)
        if not os.path.isfile(entry_path):
            continue
        with autoopen(entry_path, 'r') as handle:
            handle.readline()  # discard the header line
            excluded.update(int(row.split('\t')[0]) for row in handle)
    print("{} excluded qids loaded".format(len(excluded)))
    return excluded
def main(args):
    """Score a run from parsed command-line arguments and print the metrics.

    Args:
        args: object with the attributes ``run`` (run file path),
            ``judgments`` (judgments file path) and ``exclude`` (optional
            folder of exclude files; falsy when not supplied).
    """
    # The exclude folder is optional; without it nothing is excluded.
    exclude_qids = load_exclude(args.exclude) if args.exclude else set()

    # Positional order of compute_metrics_from_files: judgments first, run second.
    metrics = compute_metrics_from_files(args.judgments, args.run, exclude_qids)

    banner = '#####################'
    print(banner)
    for name in sorted(metrics):
        print('{}: {}'.format(name, metrics[name]))
    print(banner)
if __name__ == '__main__':
    # Command-line entry point: every option takes a single path string.
    cli = argparse.ArgumentParser(description='Official evaluation script for the MS MARCO Document Ranking task.')
    option_table = (
        ('--run', True, 'Run file.'),
        ('--judgments', True, 'Judgments.'),
        ('--exclude', False, 'Exclude directory.'),
    )
    for flag, is_required, help_text in option_table:
        cli.add_argument(flag, type=str, metavar='file', required=is_required, help=help_text)
    main(cli.parse_args())
|