NetsPresso_QA / integrations /sparse /test_simple_fusion_search_integration.py
geonmin-kim's picture
Upload folder using huggingface_hub
d6585f5
raw
history blame
No virus
4.3 kB
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import filecmp
import gzip
import os
import shutil
import unittest
from tqdm import tqdm
from pyserini.fusion import FusionMethod
from pyserini.search import get_topics
from pyserini.search import LuceneFusionSearcher
from pyserini.trectools import TrecRun
from pyserini.util import download_url, download_and_unpack_index
class TestSearchIntegration(unittest.TestCase):
    """Integration test: RRF fusion over three CORD-19 indexes must reproduce a reference run."""

    # Pre-built CORD-19 (2020-05-01) Lucene indexes; the searcher under test fuses all three.
    INDEX_URL_PREFIX = 'https://git.uwaterloo.ca/jimmylin/cord19-indexes/raw/master/2020-05-01/'
    INDEX_NAMES = [
        'lucene-index-cord19-abstract-2020-05-01',
        'lucene-index-cord19-full-text-2020-05-01',
        'lucene-index-cord19-paragraph-2020-05-01',
    ]
    # Reference fused run (TREC-COVID round 2) that our output is compared against.
    REFERENCE_RUN_URL = ('https://git.uwaterloo.ca/jimmylin/covidex-trec-covid-runs/raw/master/'
                         'round2/anserini.covid-r2.fusion1.txt.gz')
    REFERENCE_RUN_GZ = 'runs/anserini.covid-r2.fusion1.txt.gz'
    REFERENCE_RUN = 'runs/anserini.covid-r2.fusion1.txt'
    FUSED_RUN = 'runs/fused.txt'
    # Only ranks up to this cutoff are compared (see the comment in the test for why).
    MAX_COMPARED_RANK = 20

    def setUp(self):
        """Download the three indexes and the gzipped reference run, then decompress the run."""
        for name in self.INDEX_NAMES:
            download_and_unpack_index(self.INDEX_URL_PREFIX + name + '.tar.gz')
        download_url(self.REFERENCE_RUN_URL, 'runs')
        # from https://stackoverflow.com/questions/31028815/how-to-unzip-gz-file-using-python
        with gzip.open(self.REFERENCE_RUN_GZ, 'rb') as f_in, open(self.REFERENCE_RUN, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)

    @staticmethod
    def _top_ranked(path, max_rank):
        """Return ``(topic, docid, rank)`` string triples for run lines with ``rank <= max_rank``.

        Pure-Python replacement for ``awk '$4 <= N {print $1" "$3" "$4}'``. Lines are split on
        whitespace, the rank (4th column) is compared numerically, and the score and tag
        columns are dropped. Line order is preserved. Lines with fewer than four columns
        (e.g. blank lines) are skipped.

        Raises:
            ValueError: if a line's rank column is not an integer.
        """
        triples = []
        with open(path, encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if len(fields) < 4:
                    continue
                if int(fields[3]) <= max_rank:
                    triples.append((fields[0], fields[2], fields[3]))
        return triples

    def test_simple_fusion_searcher(self):
        """Fuse the three indexes with RRF and compare the top ranks against the reference run."""
        index_dirs = ['indexes/' + name + '/' for name in self.INDEX_NAMES]
        searcher = LuceneFusionSearcher(index_dirs, method=FusionMethod.RRF)

        runs, topics = [], get_topics('covid-round2')
        for topic in tqdm(sorted(topics.keys())):
            query = topics[topic]['question'] + ' ' + topics[topic]['query']
            hits = searcher.search(query, k=10000, query_generator=None, strip_segment_id=True, remove_dups=True)
            docid_score_pair = [(hit.docid, hit.score) for hit in hits]
            runs.append(TrecRun.from_search_results(docid_score_pair, topic=topic))

        all_topics_run = TrecRun.concat(runs)
        all_topics_run.save_to_txt(output_path=self.FUSED_RUN, tag='reciprocal_rank_fusion_k=60')

        # Only keep topic, docid, and rank. Scores may be slightly different due to floating point
        # precision issues and underlying lib versions.
        # This has also proven to be a somewhat brittle test, see https://github.com/castorini/pyserini/issues/947
        # A stopgap for above issue, we're restricting comparison to only top-100 ranks.
        #
        # Another update (2022/09/17): This test broke again in the Lucene 8->9 upgrade.
        # Fixed by restricting comparisons to only top 20.
        #
        # The comparison is done in Python rather than via shell `awk` + `filecmp.cmp`. The shell
        # version could pass vacuously when `awk` was unavailable, because both redirect targets
        # ended up empty. `filecmp.cmp`'s default shallow mode could also report equality from
        # matching size and mtime alone.
        actual = self._top_ranked(self.FUSED_RUN, self.MAX_COMPARED_RANK)
        expected = self._top_ranked(self.REFERENCE_RUN, self.MAX_COMPARED_RANK)
        # Guard against a vacuous pass on an empty/corrupt reference download.
        self.assertTrue(expected, 'reference run contains no lines within the compared ranks')
        self.assertEqual(actual, expected)

    def tearDown(self):
        """Remove downloaded indexes and run files.

        Cleanup is best-effort. Artifacts that were never created (e.g. the test failed before
        writing the fused run) are skipped, so cleanup errors do not mask the original failure.
        """
        for name in self.INDEX_NAMES:
            shutil.rmtree('indexes/' + name, ignore_errors=True)
        for path in (self.REFERENCE_RUN_GZ, self.REFERENCE_RUN, self.FUSED_RUN):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
if __name__ == '__main__':
    # Run directly as a script: discover and execute the TestCase classes defined in this
    # module (module='__main__' is unittest.main's default, spelled out for clarity).
    unittest.main(module='__main__')