#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Integration tests for commands in Pradeep et al. resource paper at ECIR 2023."""

import os
import unittest

from integrations.utils import clean_files, run_command, parse_score_qa


class TestECIR2023(unittest.TestCase):
    def setUp(self):
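        # files created by a test are registered here and removed in tearDown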
        self.temp_files = []

    def test_section5_sub2_first(self):
        """Sample code of the first command in Section 5.2."""
        metrics = ['Top5', 'Top20', 'Top100']
        ground_truth = [73.8, 84.27, 89.34]
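        # expected Top-{5,20,100} retrieval accuracies (as percentages) that the
        # run should reproduce, within the tolerance checked below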

        output_file = 'runs/run.nq-test.dkrr.trec'
        json_file = 'runs/run.nq-test.dkrr.json'
        self.temp_files.append(output_file)
        self.temp_files.append(json_file)

        # step 1: dense retrieval over the DKRR Wikipedia index with the DKRR-DPR query encoder
        run_cmd = f'python -m pyserini.search.faiss \
                      --index wikipedia-dpr-dkrr-nq \
                      --topics nq-test \
                      --encoder castorini/dkrr-dpr-nq-retriever \
                      --output {output_file} --query-prefix question: \
                      --threads 72 --batch-size 72 \
                      --hits 100'
        status = os.system(run_cmd)
        self.assertEqual(status, 0)

        # step 2: convert the TREC run into the DPR retrieval JSON format
        convert_cmd = f'python -m pyserini.eval.convert_trec_run_to_dpr_retrieval_run \
                        --topics nq-test \
                        --index wikipedia-dpr \
                        --input {output_file} \
                        --output {json_file}'
        status = os.system(convert_cmd)
        self.assertEqual(status, 0)

        # step 3: compute Top-k retrieval accuracy on the converted run
        eval_cmd = f'python -m pyserini.eval.evaluate_dpr_retrieval \
                       --retrieval {json_file} \
                       --topk 5 20 100'
        stdout, stderr = run_command(eval_cmd)

        # parse_score_qa pulls each Top-k accuracy out of stdout; convert the
        # fractions to percentages to match the ground-truth values above
        scores = []
        for metric in metrics:
            scores.append(parse_score_qa(stdout, metric, 4) * 100)

        for score, expected in zip(scores, ground_truth):
            self.assertAlmostEqual(score, expected, delta=0.02)

    def test_section5_sub2_second(self):
        """Sample code of the second command in Section 5.2."""

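        # run_all_odqa.py reproduces the ODQA results matrix for the given topic
        # set, printing '[OK]' for each condition whose score matches the
        # expected value; each topic set covers 21 conditions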
        cmd_nq = 'python scripts/repro_matrix/run_all_odqa.py --topics nq'
        # assumption: 'dpr-trivia' is the TriviaQA topic key accepted by
        # run_all_odqa.py (the nq topics clearly do not belong here)
        cmd_tqa = 'python scripts/repro_matrix/run_all_odqa.py --topics dpr-trivia'

        # run both commands and check that every condition passed (i.e., printed '[OK]')
        stdout_nq, stderr_nq = run_command(cmd_nq)
        self.assertEqual(stdout_nq.count('[OK]'), 21)

        stdout_tqa, stderr_tqa = run_command(cmd_tqa)
        self.assertEqual(stdout_tqa.count('[OK]'), 21)

    def tearDown(self):
        clean_files(self.temp_files)


if __name__ == '__main__':
    unittest.main()