# NetsPresso_QA / analyze_answer_inclusion_in_retrieval.py
# geonmin-kim's picture
# Upload folder using huggingface_hub
# d6585f5
import json
from pyserini.search.lucene import LuceneSearcher
from tqdm import tqdm
def convert_unicode_to_normal(data):
    """Round-trip text (or a list of texts) through UTF-8 and return it.

    For valid text, ``s.encode('utf-8').decode('utf-8')`` returns a string
    equal to ``s``. In practice this therefore acts as a UTF-8 validity
    check: it raises ``UnicodeEncodeError`` for strings that cannot be
    encoded as UTF-8 (e.g. strings containing lone surrogates).

    Args:
        data: a ``str``, or a list whose elements are all ``str``.

    Returns:
        A ``str`` for ``str`` input, or a new list of ``str`` for list input.
        An empty list yields an empty list.

    Raises:
        ValueError: if ``data`` is neither a ``str`` nor a list, or if any
            list element is not a ``str``.
        UnicodeEncodeError: if some text cannot be encoded as UTF-8.
    """
    if isinstance(data, str):
        return data.encode('utf-8').decode('utf-8')
    if isinstance(data, list):
        # Check every element, not just the first; an empty list is valid.
        for sample in data:
            if not isinstance(sample, str):
                raise ValueError(
                    "expected a list of str, got an element of type "
                    f"{type(sample).__name__}"
                )
        return [sample.encode('utf-8').decode('utf-8') for sample in data]
    raise ValueError(f"expected str or list of str, got {type(data).__name__}")
# ---- Experiment configuration -------------------------------------------
# Retrieval depth; it is also embedded in the run-file name below.
K = 30
# Lucene index directory.
index_dir = "/root/indexes/index-wikipedia-dpr-20210120"
# BM25 run file (one line per retrieved passage).
runfile_path = "runs/q=NQtest_c=wikidpr_m=bm25_k={}.run".format(K)
# QA pairs, one JSON object per line.
qafile_path = (
    "/root/nota-fairseq/examples/information_retrieval/"
    "open_domain_data/NQ/qa_pairs/test.jsonl"
)
# Output JSONL of (query, passage) hits that contain an answer string.
logging_path = "logging_q=NQ_c=wiki_including_ans.jsonl"
# Open a Lucene searcher over the index directory stored in `index_dir`.
# It is constructed from a local path, not via pyserini's
# `LuceneSearcher.from_prebuilt_index`.
searcher = LuceneSearcher(
    index_dir=index_dir,
)
# v2: load every QA pair into memory first, keyed by qid. Run-file lines can
# then be matched to their query in whatever order the run file lists queries
# (the run file is sorted by query name, so a lock-step read of both files
# does not line up).
print("read qa file")
pair_by_qid = {}
with open(qafile_path, 'r', encoding='utf-8') as fr_qa:
    for pair in tqdm(fr_qa):
        if not pair.strip():
            continue  # tolerate blank lines, e.g. extra newlines at EOF
        pair_data = json.loads(pair)
        qid, query, answers = pair_data["qid"], pair_data["query"], pair_data["answers"]  # str, str, list
        # Key by str(qid): qids read from the run file are always strings, so a
        # numeric JSON qid would otherwise never match.
        pair_by_qid[str(qid)] = {'query': query, 'answers': answers}
def _load_passage_text(lucene_searcher, pid):
    """Return the cleaned ``contents`` text of passage ``pid`` from the index.

    The stored raw document is parsed as JSON. Its ``contents`` field is
    stripped and passed through ``convert_unicode_to_normal``.

    Raises:
        KeyError: if the searcher has no document for ``pid``.
    """
    doc = lucene_searcher.doc(pid)
    if doc is None:  # pyserini's LuceneSearcher.doc returns None for unknown ids
        raise KeyError(f"passage id {pid!r} not found in index")
    contents = json.loads(doc.raw())['contents'].strip()
    return convert_unicode_to_normal(contents)


print("check retrieved passage include answer")
# qids (in first-hit order) having at least one retrieved passage that
# contains one of their answer strings.
qid_with_ans_in_retrieval = []
_seen_qids = set()  # O(1) membership mirror of qid_with_ans_in_retrieval
# Both files are UTF-8: the log is written with ensure_ascii=False, so
# non-ASCII text would fail under a non-UTF-8 locale default encoding.
with open(runfile_path, 'r', encoding='utf-8') as fr_run, \
        open(logging_path, 'w', encoding='utf-8') as fw_log:
    for line_no, result in enumerate(tqdm(fr_run), start=1):
        if not result.strip():
            continue  # skip blank lines
        # split() (any whitespace) also drops the trailing newline.
        fields = result.split()
        if len(fields) != 6:  # qid q_type pid k score engine
            raise ValueError(
                f"{runfile_path}:{line_no}: expected 6 fields, "
                f"got {len(fields)}: {result!r}"
            )
        qid_, pid = fields[0], fields[2]
        if qid_ not in pair_by_qid:
            raise KeyError(f"{runfile_path}:{line_no}: qid {qid_!r} not found in {qafile_path}")
        query, answers = pair_by_qid[qid_]['query'], pair_by_qid[qid_]['answers']

        psg_txt = _load_passage_text(searcher, pid)

        # Plain substring match; only the first contained answer is logged,
        # since one hit is enough for this passage.
        ans = next((a for a in answers if a in psg_txt), None)
        if ans is None:
            continue
        log_w = {
            "qid": qid_,
            "pid": pid,
            "query": query,
            "answer": ans,
            "passage": psg_txt,
        }
        fw_log.write(json.dumps(log_w, ensure_ascii=False) + '\n')
        if qid_ not in _seen_qids:
            _seen_qids.add(qid_)
            qid_with_ans_in_retrieval.append(qid_)

num_qids = len(pair_by_qid)
# Recall over all QA pairs; guard against an empty QA file.
recall = len(qid_with_ans_in_retrieval) / num_qids * 100 if num_qids else 0.0
print(f"#qid in test set: {num_qids}, #qid having answer with retrieval(BM25, K={K}): {len(qid_with_ans_in_retrieval)}, Recall = {recall}")
# v1 (superseded; kept for reference only).
# NOTE(review): the triple-quoted text below is a bare string-literal
# statement, so none of it executes. It read the run file K lines at a time,
# in lock-step with the QA file, and asserted that the qids match. That
# lock-step assumption does not hold when the run file orders queries
# differently from the QA file (sorted by query name). It also contains a
# `pdb.set_trace()` debugging breakpoint. Consider deleting it; version
# control keeps the history.
"""
with open(runfile_path, 'r') as fr_run, open(qafile_path, 'r') as fr_qa:
for pair in tqdm(fr_qa):
pair_data = json.loads(pair)
qid, query, answers = pair_data["qid"], pair_data["query"], pair_data["answers"] # str, str, list
for k in range(K):
result=fr_run.readline()
print(result)
fields = result.split(' ')
assert(len(fields) == 6) # qid q_type pid k score engine
qid_, pid = fields[0], fields[2]
assert(qid == qid_), f"qid={qid}, qid_={qid_} should be same"
# get passage
psg_txt = searcher.doc(pid)
psg_txt = psg_txt.raw()
psg_txt = json.loads(psg_txt)
psg_txt = psg_txt['contents'].strip()
psg_txt = convert_unicode_to_normal(psg_txt)
# check if passage contains answer
if any([ans in psg_txt for ans in answers]):
import pdb
pdb.set_trace()
"""