# Source: NetsPresso_QA / pyserini/eval/trec_eval.py
# Uploaded by geonmin-kim ("Upload folder using huggingface_hub", commit d6585f5).
# Hugging Face file page metadata: 3.85 kB; virus scan: no virus found.
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
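# Example usage:
# python -m pyserini.eval.trec_eval -m ndcg_cut.10,20 -m all_trec qrels.dev.small.tsv runs/run.Colbert.txt -remove-unjudged -m judged.20,50
# (-remove-unjudged drops run hits whose (query, doc) pair is absent from the qrels before evaluating;
#  -m judged.K1,K2 additionally reports the fraction of judged hits among each query's first K hits,
#  pooled over all queries.)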
import os
import re
import subprocess
import sys
import platform
import pandas as pd
import tempfile
from pyserini.search import get_qrels_file
from pyserini.util import download_evaluation_script
# Command-line handling: strip the options this wrapper implements itself
# (-remove-unjudged and "-m judged.K1,K2,...") from argv; whatever remains
# after argv[0] is forwarded to the trec_eval jar further down.
script_path = download_evaluation_script('trec_eval')
cmd_prefix = ['java', '-jar', script_path]
# NOTE: this aliases sys.argv, so the pops below also edit sys.argv in place.
args = sys.argv

# Option to discard non-judged hits in the run file before evaluation.
# Holds the literal flag string when given (truthy), '' otherwise.
judged_docs_only = ''
# Formatted "judged_K" report lines, printed after trec_eval's own output.
judged_result = []
# Cutoffs K requested through "-m judged.K1,K2,...".
cutoffs = []

if '-remove-unjudged' in args:
    judged_docs_only = args.pop(args.index('-remove-unjudged'))

# Position of the first "judged.K1,K2,..." measure (argv[0] excluded), or None.
judged_idx = next(
    (i for i, arg in enumerate(args[1:], start=1) if arg.startswith('judged.')),
    None,
)
if judged_idx is not None:
    judged_spec = args.pop(judged_idx)
    cutoffs = [int(k) for k in judged_spec[len('judged.'):].split(',')]
    # The judged.* measure is computed by this script rather than passed to
    # trec_eval, so its "-m" flag must be removed as well. Only pop the
    # preceding argument when it really is "-m": popping blindly would drop
    # another option (or a path) if the user omitted "-m".
    if args[judged_idx - 1] == '-m':
        args.pop(judged_idx - 1)
def _make_temp_file():
    """Create an empty named temporary file, close its handle, and return its path."""
    # mkstemp returns an open OS-level descriptor; close it immediately so the
    # path can be rewritten and later removed without a leaked handle.
    fd, path = tempfile.mkstemp()
    os.close(fd)
    return path


def _judged_mask(run_df, qrels_df):
    """Return a boolean array marking run rows whose (query id, doc id) pair is in the qrels.

    Both frames are addressed by column position: this script treats column 0
    as the query id and column 2 as the doc id in both files (the msmarco
    conversion below places doc_id at column 2). Duplicate qrels rows do not
    duplicate or over-count run rows.
    """
    judged_pairs = pd.MultiIndex.from_frame(qrels_df[[0, 2]])
    return pd.MultiIndex.from_frame(run_df[[0, 2]]).isin(judged_pairs)


temp_file = ''
if len(args) > 1:
    # args[-2] is the qrels; when it is not an existing path, resolve it by name.
    if not os.path.exists(args[-2]):
        args[-2] = get_qrels_file(args[-2])
    if os.path.exists(args[-1]):
        # Parse whitespace-delimited files with no NA detection. Ids are kept
        # as verbatim strings below: numeric inference would drop leading zeros
        # ("001" -> 1), could give run and qrels id columns different dtypes,
        # and would change the ids in any run file we write back for trec_eval.
        read_opts = dict(sep=r'\s+', header=None, na_filter=False)

        # Convert an msmarco-style run (no "Q0" column) to TREC format.
        with open(args[-1]) as f:
            first_line = f.readline()
        if 'Q0' not in first_line:
            temp_file = _make_temp_file()
            print('msmarco run detected. Converting to trec...')
            run = pd.read_csv(args[-1], names=['query_id', 'doc_id', 'rank'],
                              dtype={'query_id': str, 'doc_id': str}, **read_opts)
            run['score'] = 1 / run['rank']
            run.insert(1, 'Q0', 'Q0')
            run['name'] = 'TEMPRUN'
            run.to_csv(temp_file, sep='\t', header=False, index=False)
            args[-1] = temp_file

        # Every column is read as text so a filtered run is written back unchanged.
        run = pd.read_csv(args[-1], dtype=str, **read_opts)
        qrels = pd.read_csv(args[-2], dtype=str, **read_opts)

        # Discard non-judged hits (reuses the conversion temp file if there is one).
        if judged_docs_only:
            if not temp_file:
                temp_file = _make_temp_file()
            run = run[_judged_mask(run, qrels)]
            run.to_csv(temp_file, sep='\t', header=False, index=False)
            args[-1] = temp_file

        # Measure judged@cutoffs: among each query's first `cutoff` hits (file
        # order), the fraction that are judged, pooled over all queries.
        # An empty run reports 0 instead of dividing by zero.
        for cutoff in cutoffs:
            run_cutoff = run.groupby(0).head(cutoff)
            judged = _judged_mask(run_cutoff, qrels).mean() if len(run_cutoff) else 0.0
            metric_name = f'judged_{cutoff}'
            judged_result.append(f'{metric_name:22}\tall\t{judged:.4f}')
    cmd = cmd_prefix + args[1:]
else:
    cmd = cmd_prefix
# Run the trec_eval jar, print its output followed by the judged@k lines, and
# clean up any temporary run file created above.
print(f'Running command: {cmd}')
# On Windows the command is launched through the shell; elsewhere it runs directly.
shell = platform.system() == "Windows"
try:
    process = subprocess.Popen(cmd,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               shell=shell)
    stdout, stderr = process.communicate()
finally:
    # Remove the converted/filtered run even if launching trec_eval failed
    # (e.g. FileNotFoundError when java cannot be found).
    if temp_file:
        os.remove(temp_file)

# errors="replace": shell/JVM messages are not guaranteed to be valid UTF-8
# (notably on Windows); a decode error must not hide the results.
if stderr:
    print(stderr.decode("utf-8", errors="replace"))
print('Results:')
print(stdout.decode("utf-8", errors="replace").rstrip())
for judged in judged_result:
    print(judged)

# Surface trec_eval failures to the caller instead of always exiting 0.
if process.returncode:
    sys.exit(process.returncode)