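"""Check whether BM25-retrieved Wikipedia passages contain a gold NQ answer.

Reads the NQ test QA pairs (jsonl) and a BM25 run file, looks each retrieved
passage up in the Lucene index, writes every (qid, pid, answer) match to a
jsonl log, and reports answer recall of the top-K retrieval over the test set.
"""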
import json

from pyserini.search.lucene import LuceneSearcher
from tqdm import tqdm

def convert_unicode_to_normal(data):
    """Round-trip a string (or list of strings) through UTF-8 to normalize the encoding."""
    if isinstance(data, str):
        return data.encode('utf-8').decode('utf-8')
    elif isinstance(data, list):
        assert all(isinstance(sample, str) for sample in data)
        return [sample.encode('utf-8').decode('utf-8') for sample in data]
    else:
        raise ValueError(f"expected str or list of str, got {type(data)}")

K = 30
index_dir = "/root/indexes/index-wikipedia-dpr-20210120"  # Lucene index over the DPR Wikipedia passages
runfile_path = f"runs/q=NQtest_c=wikidpr_m=bm25_k={K}.run"  # BM25 run file
qafile_path = "/root/nota-fairseq/examples/information_retrieval/open_domain_data/NQ/qa_pairs/test.jsonl"
logging_path = "logging_q=NQ_c=wiki_including_ans.jsonl"
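# Expected input formats (inferred from the parsing below):
#   - run file: one line per retrieved passage, "qid q_type pid k score engine" (TREC run format)
#   - QA file:  jsonl, each line carrying "qid" (str), "query" (str), "answers" (list of str)
# A run file in this layout could, for example, be produced with Pyserini's search CLI; the
# command below is only a sketch with placeholder index/topic names, not taken from this repo:
#   python -m pyserini.search.lucene --index <index_dir> --topics <nq-test-topics> \
#       --output runs/q=NQtest_c=wikidpr_m=bm25_k=30.run --bm25 --hits 30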

# define the searcher over the pre-built Lucene index (local copy)
searcher = LuceneSearcher(index_dir=index_dir)
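# (alternative, stated as an assumption rather than what this script does: Pyserini can also
#  fetch the same kind of prebuilt index by name, e.g.
#  searcher = LuceneSearcher.from_prebuilt_index('wikipedia-dpr'),
#  which downloads the index on first use.)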

# v2: read the QA file first (the run file is sorted by query id, so its lines are not aligned with the QA file order)
print("read QA file")
pair_by_qid = {}
with open(qafile_path, 'r') as fr_qa:
    for pair in tqdm(fr_qa):
        pair_data = json.loads(pair)
        qid, query, answers = pair_data["qid"], pair_data["query"], pair_data["answers"] # str, str, list
        pair_by_qid[qid] = {'query': query, 'answers':answers}
        
print("check retrieved passage include answer")
qid_with_ans_in_retrieval = []
with open(runfile_path, 'r') as fr_run, open(logging_path, 'w') as fw_log:
    for result in tqdm(fr_run):
        fields = result.strip().split()
        assert len(fields) == 6  # qid q_type pid k score engine

        qid_, pid = fields[0], fields[2]
        assert qid_ in pair_by_qid, f"qid {qid_} from the run file is missing from the QA file"
        query, answers = pair_by_qid[qid_]['query'], pair_by_qid[qid_]['answers']

        # get passage: look the docid up in the index; raw() returns the stored JSON document
        doc = searcher.doc(pid)
        assert doc is not None, f"pid {pid} not found in the index"
        psg_txt = json.loads(doc.raw())['contents'].strip()
        psg_txt = convert_unicode_to_normal(psg_txt)

        # check if passage contains answer
        #if any([ans in psg_txt for ans in answers]):
        for ans in answers:
            if ans in psg_txt:
                log_w = {
                    "qid": qid_,
                    "pid": pid,
                    "query": query,
                    "answer": ans,
                    "passage": psg_txt
                }
                fw_log.write(json.dumps(log_w, ensure_ascii=False) + '\n')

                if qid_ not in qid_with_ans_in_retrieval:
                    qid_with_ans_in_retrieval.append(qid_)
                break  # one matching answer is enough; no need to check the remaining answers for this passage


print(f"#qid in test set: {len(pair_by_qid.keys())}, #qid having answer with retrieval(BM25, K={K}): {len(qid_with_ans_in_retrieval)}, Recall = {len(qid_with_ans_in_retrieval)/len(pair_by_qid.keys())*100}")

# v1 (superseded): assumed the run file lines appear in the same qid order as the QA file
"""
with open(runfile_path, 'r') as fr_run, open(qafile_path, 'r') as fr_qa:
    for pair in tqdm(fr_qa):
        pair_data = json.loads(pair)
        qid, query, answers = pair_data["qid"], pair_data["query"], pair_data["answers"] # str, str, list

        for k in range(K):
            result=fr_run.readline()
            print(result)

            fields = result.split(' ')
            assert(len(fields) == 6) # qid q_type pid k score engine

            qid_, pid = fields[0], fields[2]

            assert(qid == qid_), f"qid={qid}, qid_={qid_} should be same"

            # get passage
            psg_txt = searcher.doc(pid)
            psg_txt = psg_txt.raw()
            psg_txt = json.loads(psg_txt)
            psg_txt = psg_txt['contents'].strip()
            psg_txt = convert_unicode_to_normal(psg_txt)

            # check if passage contains answer
            if any([ans in psg_txt for ans in answers]):
                import pdb
                pdb.set_trace()
"""