import json

from pyserini.search.lucene import LuceneSearcher
from tqdm import tqdm


def convert_unicode_to_normal(data):
    """Round-trip a string, or a list of strings, through UTF-8.

    Each string is encoded to UTF-8 bytes and decoded back. For text that
    is encodable this yields an equal string; text that cannot be encoded
    (for example one containing lone surrogates) raises
    ``UnicodeEncodeError``.

    Args:
        data: A ``str`` or a ``list`` of ``str``. An empty list is allowed.

    Returns:
        A ``str`` when given a ``str``; a new ``list`` of ``str`` when given
        a list.

    Raises:
        ValueError: If ``data`` is neither a ``str`` nor a list made up
            entirely of ``str`` elements.
    """
    if isinstance(data, str):
        return data.encode('utf-8').decode('utf-8')
    if isinstance(data, list):
        # Validate every element (not only the first) and tolerate an empty
        # list, which used to fail with IndexError on data[0].
        if not all(isinstance(sample, str) for sample in data):
            raise ValueError("every element of a list argument must be a str")
        return [sample.encode('utf-8').decode('utf-8') for sample in data]
    raise ValueError(
        f"expected a str or a list of str, got {type(data).__name__}"
    )


K = 30
index_dir = "/root/indexes/index-wikipedia-dpr-20210120"  # lucene
runfile_path = f"runs/q=NQtest_c=wikidpr_m=bm25_k={K}.run"  # bm25
qafile_path = "/root/nota-fairseq/examples/information_retrieval/open_domain_data/NQ/qa_pairs/test.jsonl"
logging_path = "logging_q=NQ_c=wiki_including_ans.jsonl"

# define searcher with pre-built indexes
searcher = LuceneSearcher(index_dir=index_dir)

# v2. read qa first (due to runfile query name sort)
# All QA pairs are loaded into a dict keyed by qid so that every run-file
# line can be resolved by lookup, independent of the run file's ordering.
print("read qa file")
pair_by_qid = {}
# JSON Lines files are read as UTF-8 explicitly instead of relying on the
# locale's default encoding.
with open(qafile_path, 'r', encoding='utf-8') as fr_qa:
    for pair in tqdm(fr_qa):
        if not pair.strip():
            continue  # ignore blank lines (e.g. a trailing newline at EOF)
        pair_data = json.loads(pair)
        # qid: str, query: str, answers: list
        pair_by_qid[pair_data["qid"]] = {
            'query': pair_data["query"],
            'answers': pair_data["answers"],
        }

print("check retrieved passage include answer")
# A set gives O(1) de-duplication of qids; len() on it is what the final
# recall report needs.
qid_with_ans_in_retrieval = set()
with open(runfile_path, 'r', encoding='utf-8') as fr_run, \
        open(logging_path, 'w', encoding='utf-8') as fw_log:
    for line_no, result in enumerate(tqdm(fr_run), start=1):
        # Whitespace split: tolerates tabs and strips the trailing newline.
        fields = result.split()
        if not fields:
            continue  # blank line
        # Expected columns: qid q_type pid k score engine
        if len(fields) != 6:
            raise ValueError(
                f"{runfile_path}:{line_no}: expected 6 fields, "
                f"got {len(fields)}: {result!r}"
            )
        qid_, pid = fields[0], fields[2]
        if qid_ not in pair_by_qid:
            raise ValueError(
                f"{runfile_path}:{line_no}: qid {qid_!r} "
                f"is not present in {qafile_path}"
            )
        query = pair_by_qid[qid_]['query']
        answers = pair_by_qid[qid_]['answers']

        # get passage: the stored raw document is JSON with a 'contents' field
        doc = searcher.doc(pid)
        if doc is None:
            # Fail loudly; silently skipping would distort the recall figure.
            raise ValueError(
                f"{runfile_path}:{line_no}: pid {pid!r} "
                f"was not found in index {index_dir}"
            )
        psg_txt = json.loads(doc.raw())['contents'].strip()
        psg_txt = convert_unicode_to_normal(psg_txt)

        # check if passage contains answer
        # Only the first matching answer is logged: one hit per passage is
        # enough, additional answers in the same passage are not counted.
        matched = next((ans for ans in answers if ans in psg_txt), None)
        if matched is None:
            continue
        log_w = {
            "qid": qid_,
            "pid": pid,
            "query": query,
            "answer": matched,
            "passage": psg_txt,
        }
        fw_log.write(json.dumps(log_w, ensure_ascii=False) + '\n')
        qid_with_ans_in_retrieval.add(qid_)
# Each qid is counted at most once, however many of its retrieved passages
# or gold answers matched.
num_queries = len(pair_by_qid)
num_hits = len(qid_with_ans_in_retrieval)
# Guard the division: an empty QA file would otherwise raise
# ZeroDivisionError instead of reporting a recall of 0.
recall = num_hits / num_queries * 100 if num_queries else 0.0
print(f"#qid in test set: {num_queries}, #qid having answer with retrieval(BM25, K={K}): {num_hits}, Recall = {recall}")

# v1 (superseded, never executed): walked the QA file and the run file in
# lockstep, reading K run lines per QA pair and asserting that the qids
# matched. It is kept below as an inert string literal for reference only.
"""
with open(runfile_path, 'r') as fr_run, open(qafile_path, 'r') as fr_qa:
    for pair in tqdm(fr_qa):
        pair_data = json.loads(pair)
        qid, query, answers = pair_data["qid"], pair_data["query"], pair_data["answers"] # str, str, list
        for k in range(K):
            result=fr_run.readline()
            print(result)
            fields = result.split(' ')
            assert(len(fields) == 6) # qid q_type pid k score engine
            qid_, pid = fields[0], fields[2]
            assert(qid == qid_), f"qid={qid}, qid_={qid_} should be same"
            # get passage
            psg_txt = searcher.doc(pid)
            psg_txt = psg_txt.raw()
            psg_txt = json.loads(psg_txt)
            psg_txt = psg_txt['contents'].strip()
            psg_txt = convert_unicode_to_normal(psg_txt)
            # check if passage contains answer
            if any([ans in psg_txt for ans in answers]):
                import pdb
                pdb.set_trace()
"""