#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import heapq
import json
import os
import shutil
import tarfile
import unittest
from random import randint
from urllib.request import urlretrieve

from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB

from pyserini import analysis, search
from pyserini.index.lucene import IndexReader
from pyserini.pyclass import JString
from pyserini.vectorizer import BM25Vectorizer, TfidfVectorizer


class TestIndexUtils(unittest.TestCase):
    """Tests for ``IndexReader`` and the vectorizers against a pre-built CACM index.

    Every test downloads and unpacks its own copy of the index in ``setUp`` and
    removes it again in ``tearDown``.
    """

    # Documents and labels shared by the two classifier-training tests.
    _TRAIN_DOCS = ['CACM-0239', 'CACM-0440', 'CACM-3168', 'CACM-3169']
    _TRAIN_LABELS = [1, 1, 0, 0]
    _TEST_DOCS = ['CACM-0634', 'CACM-3134']

    # Queries shared by the BM25 dump test and the quantized-weights test.
    _DUMP_QUERIES = [
        'I am interested in articles written either by Prieve or Udo Pooch',
        'Performance evaluation and modelling of computer systems',
        'Addressing schemes for resources in networks; resource addressing in network operating systems',
    ]

    # Number of top-ranked documents compared by the dump-based tests.
    _TOP_K = 10

    def setUp(self):
        """Download and unpack the CACM index, then open a searcher and a reader on it."""
        # Initialised first so that it always exists by the time tearDown runs.
        self.temp_folders = []

        # Download pre-built CACM index built using Lucene 9; append a random value to avoid filename clashes.
        r = randint(0, 10000000)
        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene9-index.cacm.tar.gz'
        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        self.index_dir = 'index{}/'.format(r)

        urlretrieve(self.collection_url, self.tarball_name)

        # The context manager closes the archive even if extraction fails.
        with tarfile.open(self.tarball_name) as tarball:
            tarball.extractall(self.index_dir)

        self.index_path = os.path.join(self.index_dir, 'lucene9-index.cacm')
        self.searcher = search.LuceneSearcher(self.index_path)
        self.index_reader = IndexReader(self.index_path)

        # The current directory depends on if you're running inside an IDE or from command line.
        if os.getcwd().endswith('tests'):
            self.emoji_corpus_path = '../tests/resources/sample_collection_json_emoji'
        else:
            self.emoji_corpus_path = 'tests/resources/sample_collection_json_emoji'

    # ------------------------------------------------------------------
    # Private helpers (not collected as tests: names do not start with "test").
    # ------------------------------------------------------------------

    @staticmethod
    def _remove_if_exists(path):
        """Delete ``path``; do nothing if the file was never created."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of ``path`` if it does not exist yet."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _load_jsonl(path):
        """Return the list of JSON objects stored one per line in ``path``."""
        with open(path, 'r') as jsonl_file:
            return [json.loads(line) for line in jsonl_file]

    def _assert_predictions(self, vectorizer, classifier, expected):
        """Fit ``classifier`` on the shared training docs and compare ``predict_proba`` to ``expected``.

        Parameters
        ----------
        vectorizer
            Object providing ``get_vectors(docids)``.
        classifier
            Object providing ``fit`` and ``predict_proba``.
        expected : list of list of float
            Expected probabilities, one row per document in ``_TEST_DOCS``.
        """
        train_vectors = vectorizer.get_vectors(self._TRAIN_DOCS)
        test_vectors = vectorizer.get_vectors(self._TEST_DOCS)

        classifier.fit(train_vectors, self._TRAIN_LABELS)
        pred = classifier.predict_proba(test_vectors)

        for row, expected_row in enumerate(expected):
            for col, value in enumerate(expected_row):
                self.assertAlmostEqual(value, pred[row][col], places=8)

    def _posting_for(self, term, collection_docid):
        """Return the posting of ``term`` whose document is ``collection_docid``, or None if absent."""
        for posting in self.index_reader.get_postings_list(term):
            docid = self.index_reader.convert_internal_docid_to_collection_docid(posting.docid)
            if docid == collection_docid:
                return posting
        return None

    def _assert_scores_match(self, queries, similarity=None):
        """Check each hit's score against ``compute_query_document_score`` for every query.

        When ``similarity`` is None the ``similarity`` keyword is not passed at all,
        so the index reader uses whatever its default is.
        """
        extra = {} if similarity is None else {'similarity': similarity}
        for query in queries:
            for hit in self.searcher.search(query):
                self.assertAlmostEqual(
                    hit.score,
                    self.index_reader.compute_query_document_score(hit.docid, query, **extra),
                    places=4)

    def _brute_force_top(self, docs, query):
        """Rank dumped documents for ``query`` by summing the weights of matching query terms.

        Parameters
        ----------
        docs : list of dict
            Parsed dump entries; each has an ``'id'`` and a ``'vector'`` mapping term -> weight.
        query : str
            The query for search.

        Returns
        -------
        list of (float, str)
            The ``_TOP_K`` best ``(negated score, docid)`` pairs, best first.
            Scores are negated so that ascending tuple order is descending score
            order; ties are broken by docid.
        """
        query_terms = self.index_reader.analyze(query, analyzer=analysis.get_lucene_analyzer())
        scored = (
            (-sum(doc['vector'][term] for term in query_terms if term in doc['vector']), doc['id'])
            for doc in docs
        )
        return heapq.nsmallest(self._TOP_K, scored)

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    # See https://github.com/castorini/pyserini/issues/770
    # tldr -- a longstanding issue about whether we need the `encode` in `JString(my_str.encode('utf-8'))`.
    # As it turns out, the solution is to remove the `JString` wrapping, which also has performance benefits as well.
    # See:
    #  - https://github.com/castorini/pyserini/pull/862
    #  - https://github.com/castorini/pyserini/issues/841
    def test_doc_vector_emoji_test(self):
        """Index a small corpus containing emoji and look the emoji terms up again."""
        index_dir = 'temp_index'
        self.temp_folders.append(index_dir)
        cmd1 = ('python -m pyserini.index.lucene -collection JsonCollection '
                '-generator DefaultLuceneDocumentGenerator '
                f'-threads 1 -input {self.emoji_corpus_path} -index {index_dir} -storeDocvectors')
        status = os.system(cmd1)
        # Fail here with a clear message rather than later when opening a missing index.
        self.assertEqual(status, 0, 'indexing command failed: ' + cmd1)

        temp_index_reader = IndexReader(index_dir)

        df, cf = temp_index_reader.get_term_counts('emoji')
        self.assertEqual(df, 1)
        self.assertEqual(cf, 1)

        df, cf = temp_index_reader.get_term_counts('🙂')
        self.assertEqual(df, 1)
        self.assertEqual(cf, 1)

        doc_vector = temp_index_reader.get_document_vector('doc1')
        self.assertEqual(doc_vector['emoji'], 1)
        self.assertEqual(doc_vector['🙂'], 1)
        self.assertEqual(doc_vector['😀'], 1)

    def test_tfidf_vectorizer_train(self):
        """Train naive Bayes on tf-idf vectors and check the predicted probabilities."""
        self._assert_predictions(
            TfidfVectorizer(self.index_path, min_df=5),
            MultinomialNB(),
            [[0.49975694, 0.50024306],
             [0.51837413, 0.48162587]])

    def test_bm25_vectorizer_train(self):
        """Train logistic regression on BM25 vectors and check the predicted probabilities."""
        self._assert_predictions(
            BM25Vectorizer(self.index_path, min_df=5),
            LogisticRegression(),
            [[0.4629749, 0.5370251],
             [0.48288416, 0.51711584]])

    def test_tfidf_vectorizer(self):
        """Spot-check two entries of the un-normalised tf-idf matrix."""
        vectorizer = TfidfVectorizer(self.index_path, min_df=5)
        result = vectorizer.get_vectors(['CACM-0239', 'CACM-0440'], norm=None)
        self.assertAlmostEqual(result[0, 190], 2.907369334264736, places=8)
        self.assertAlmostEqual(result[1, 391], 0.07516490235060004, places=8)

    def test_bm25_vectorizer(self):
        """Spot-check two entries of the un-normalised BM25 matrix."""
        vectorizer = BM25Vectorizer(self.index_path, min_df=5)
        result = vectorizer.get_vectors(['CACM-0239', 'CACM-0440'], norm=None)
        self.assertAlmostEqual(result[0, 190], 1.7513844966888428, places=8)
        self.assertAlmostEqual(result[1, 391], 0.03765463829040527, places=8)

    def test_vectorizer_query(self):
        """Check three entries of the vector built for a query string."""
        vectorizer = BM25Vectorizer(self.index_path, min_df=5)
        result = vectorizer.get_query_vector('this is a query to test query vector')
        self.assertEqual(result[0, 2703], 2)
        self.assertEqual(result[0, 3078], 1)
        self.assertEqual(result[0, 3204], 1)

    def test_terms_count(self):
        """Iterate through the index and make sure we have the correct number of terms."""
        self.assertEqual(sum(1 for _ in self.index_reader.terms()), 14363)

    def test_terms_contents(self):
        """Examine the first two index terms to make sure the statistics are correct."""
        iterator = self.index_reader.terms()

        index_term = next(iterator)
        self.assertEqual(index_term.term, '0')
        self.assertEqual(index_term.df, 19)
        self.assertEqual(index_term.cf, 30)

        index_term = next(iterator)
        self.assertEqual(index_term.term, '0,1')
        self.assertEqual(index_term.df, 1)
        self.assertEqual(index_term.cf, 1)

    def test_analyze(self):
        """Analyze text with the default analyzer and with a non-stemming analyzer."""
        self.assertEqual(' '.join(self.index_reader.analyze('retrieval')), 'retriev')
        self.assertEqual(' '.join(self.index_reader.analyze('rapid retrieval, space economy')),
                         'rapid retriev space economi')

        tokenizer = analysis.get_lucene_analyzer(stemming=False)
        self.assertEqual(' '.join(self.index_reader.analyze('retrieval', analyzer=tokenizer)), 'retrieval')
        self.assertEqual(' '.join(self.index_reader.analyze('rapid retrieval, space economy', analyzer=tokenizer)),
                         'rapid retrieval space economy')

        # Test utf encoding:
        self.assertEqual(self.index_reader.analyze('zoölogy')[0], 'zoölog')
        self.assertEqual(self.index_reader.analyze('zoölogy', analyzer=tokenizer)[0], 'zoölogy')

    def test_term_stats(self):
        """Check document and collection frequencies with and without analysis."""
        df, cf = self.index_reader.get_term_counts('retrieval')
        self.assertEqual(df, 138)
        self.assertEqual(cf, 275)

        df, cf = self.index_reader.get_term_counts('information retrieval')
        self.assertEqual(df, 74)
        self.assertIsNone(cf)

        df_no_stem, cf_no_stem = self.index_reader.get_term_counts('retrieval', analyzer=None)
        # 'retrieval' does not occur as a stemmed word, only 'retriev' does.
        self.assertEqual(df_no_stem, 0)
        self.assertEqual(cf_no_stem, 0)

        df_no_stopword, cf_no_stopword = self.index_reader.get_term_counts('on', analyzer=None)
        self.assertEqual(df_no_stopword, 326)
        self.assertEqual(cf_no_stopword, 443)

        # Should gracefully handle non-existent term.
        df, cf = self.index_reader.get_term_counts('sdgsc')
        self.assertEqual(df, 0)
        self.assertEqual(cf, 0)

    def test_postings1(self):
        """Check the first and last posting of the postings list for 'retrieval'."""
        term = 'retrieval'
        postings = list(self.index_reader.get_postings_list(term))
        self.assertEqual(len(postings), 138)

        first, last = postings[0], postings[-1]

        self.assertEqual(first.docid, 238)
        self.assertEqual(self.index_reader.convert_internal_docid_to_collection_docid(first.docid), 'CACM-0239')
        self.assertEqual(first.tf, 1)
        self.assertEqual(len(first.positions), 1)

        self.assertEqual(last.docid, 3168)
        self.assertEqual(self.index_reader.convert_internal_docid_to_collection_docid(last.docid), 'CACM-3169')
        self.assertEqual(last.tf, 1)
        self.assertEqual(len(last.positions), 1)

    def test_postings2(self):
        """Check postings lookups for missing, analyzed and pre-analyzed terms."""
        self.assertIsNone(self.index_reader.get_postings_list('asdf'))

        postings = list(self.index_reader.get_postings_list('retrieval'))
        self.assertEqual(len(postings), 138)

        # If we don't analyze, then we can't find the postings list:
        self.assertIsNone(self.index_reader.get_postings_list('retrieval', analyzer=None))

        # Supply the analyzed form directly, and we're good:
        postings = list(self.index_reader.get_postings_list('retriev', analyzer=None))
        self.assertEqual(len(postings), 138)

        postings = list(self.index_reader.get_postings_list(self.index_reader.analyze('retrieval')[0], analyzer=None))
        self.assertEqual(len(postings), 138)

        # Test utf encoding:
        self.assertIsNone(self.index_reader.get_postings_list('zoölogy'))
        self.assertIsNone(self.index_reader.get_postings_list('zoölogy', analyzer=None))
        self.assertIsNone(self.index_reader.get_postings_list('zoölogy'))

    def test_doc_vector(self):
        """Check the size and two entries of a document vector."""
        doc_vector = self.index_reader.get_document_vector('CACM-3134')
        self.assertEqual(len(doc_vector), 94)
        self.assertEqual(doc_vector['inform'], 8)
        self.assertEqual(doc_vector['retriev'], 7)

    def test_doc_vector_invalid(self):
        """An unknown docid yields no document vector."""
        self.assertIsNone(self.index_reader.get_document_vector('foo'))

    def test_doc_vector_matches_index(self):
        """The tf in the document vector agrees with the tf in the postings list."""
        # From the document vector, look up the term frequency of "information".
        doc_vector = self.index_reader.get_document_vector('CACM-3134')
        self.assertEqual(doc_vector['inform'], 8)

        # Now look up the postings list for "information" and find the matching document.
        posting = self._posting_for('information', 'CACM-3134')
        # Without this check the test would pass vacuously if the document were missing.
        self.assertIsNotNone(posting, 'CACM-3134 not found in the postings list')
        # The tf values should match.
        self.assertEqual(posting.tf, 8)

    def test_term_position(self):
        """Check the size and two entries of a document's term-position mapping."""
        term_positions = self.index_reader.get_term_positions('CACM-3134')
        self.assertEqual(len(term_positions), 94)
        self.assertEqual(term_positions['inform'], [7, 24, 36, 46, 60, 112, 121, 159])
        self.assertEqual(term_positions['retriev'], [10, 20, 44, 132, 160, 164, 172])

    def test_term_position_invalid(self):
        """An unknown docid yields no term positions."""
        self.assertIsNone(self.index_reader.get_term_positions('foo'))

    def test_term_position_matches_index(self):
        """The positions in the term-position mapping agree with the postings list."""
        # From the term positions mapping, look up the position list of "information".
        term_positions = self.index_reader.get_term_positions('CACM-3134')
        self.assertEqual(term_positions['inform'], [7, 24, 36, 46, 60, 112, 121, 159])

        # Now look up the postings list for "information" and find the matching document.
        posting = self._posting_for('information', 'CACM-3134')
        # Without this check the test would pass vacuously if the document were missing.
        self.assertIsNotNone(posting, 'CACM-3134 not found in the postings list')
        # The position list should match.
        self.assertEqual(posting.positions, [7, 24, 36, 46, 60, 112, 121, 159])

    def test_doc_invalid(self):
        """Every document accessor returns None for an unknown docid or field."""
        self.assertIsNone(self.index_reader.doc('foo'))
        self.assertIsNone(self.index_reader.doc_contents('foo'))
        self.assertIsNone(self.index_reader.doc_raw('foo'))
        self.assertIsNone(self.index_reader.doc_by_field('foo', 'bar'))

    def test_doc_raw(self):
        """Check the raw document text and that all ways of fetching it agree."""
        raw = self.index_reader.doc('CACM-3134').raw()
        self.assertIsInstance(raw, str)
        lines = raw.splitlines()
        self.assertEqual(len(lines), 55)
        # Note that the raw document contents will still have HTML tags.
        # NOTE(review): the comment above suggests the first line is an HTML tag, yet the
        # expected value below is the empty string -- confirm against the index.
        self.assertEqual(lines[0], '')
        self.assertEqual(lines[4], 'The Use of Normal Multiplication Tables')
        self.assertEqual(lines[29], 'rapid retrieval, space economy')

        # Now that we've verified the 'raw', check that alternative ways of fetching give the same results.
        self.assertEqual(raw, self.index_reader.doc_raw('CACM-3134'))
        self.assertEqual(raw, self.index_reader.doc('CACM-3134').raw())
        self.assertEqual(raw, self.index_reader.doc('CACM-3134').get('raw'))
        self.assertEqual(raw, self.index_reader.doc('CACM-3134').lucene_document().get('raw'))

    def test_doc_contents(self):
        """Check the document contents and that all ways of fetching them agree."""
        contents = self.index_reader.doc('CACM-3134').contents()
        self.assertIsInstance(contents, str)
        lines = contents.splitlines()
        self.assertEqual(len(lines), 48)
        self.assertEqual(lines[0], 'The Use of Normal Multiplication Tables')
        self.assertEqual(lines[47], '3134\t5\t3134')

        # Now that we've verified the 'contents', check that alternative ways of fetching give the same results.
        self.assertEqual(contents, self.index_reader.doc_contents('CACM-3134'))
        self.assertEqual(contents, self.index_reader.doc('CACM-3134').contents())
        self.assertEqual(contents, self.index_reader.doc('CACM-3134').get('contents'))
        self.assertEqual(contents, self.index_reader.doc('CACM-3134').lucene_document().get('contents'))

    def test_doc_by_field(self):
        """Fetching by the 'id' field returns the same document as fetching by docid."""
        self.assertEqual(self.index_reader.doc('CACM-3134').docid(),
                         self.index_reader.doc_by_field('id', 'CACM-3134').docid())

    def test_bm25_weight(self):
        """Check BM25 term weights with explicit and default parameters, analyzed and not."""
        docid = 'CACM-3134'
        weight = self.index_reader.compute_bm25_term_weight

        # Explicit k1 and b.
        self.assertAlmostEqual(weight(docid, 'inform', analyzer=None, k1=1.2, b=0.75), 1.925014, places=5)
        self.assertAlmostEqual(weight(docid, 'information', k1=1.2, b=0.75), 1.925014, places=5)
        self.assertAlmostEqual(weight(docid, 'retriev', analyzer=None, k1=1.2, b=0.75), 2.496352, places=5)
        self.assertAlmostEqual(weight(docid, 'retrieval', k1=1.2, b=0.75), 2.496352, places=5)

        # Default k1 and b.
        self.assertAlmostEqual(weight(docid, 'inform', analyzer=None), 2.06514, places=5)
        self.assertAlmostEqual(weight(docid, 'information'), 2.06514, places=5)
        self.assertAlmostEqual(weight(docid, 'retriev', analyzer=None), 2.70038, places=5)
        self.assertAlmostEqual(weight(docid, 'retrieval'), 2.70038, places=5)

        # A term that the document does not contain has weight zero.
        self.assertAlmostEqual(weight(docid, 'fox', analyzer=None), 0., places=5)
        self.assertAlmostEqual(weight(docid, 'fox'), 0., places=5)

    def test_docid_converstion(self):
        """Convert between internal and collection docids in both directions."""
        self.assertEqual(self.index_reader.convert_internal_docid_to_collection_docid(1), 'CACM-0002')
        self.assertEqual(self.index_reader.convert_collection_docid_to_internal_docid('CACM-0002'), 1)
        self.assertEqual(self.index_reader.convert_internal_docid_to_collection_docid(1000), 'CACM-1001')
        self.assertEqual(self.index_reader.convert_collection_docid_to_internal_docid('CACM-1001'), 1000)

    def test_query_doc_score_default(self):
        """Searcher scores match ``compute_query_document_score`` with default settings."""
        self._assert_scores_match(['information retrieval', 'databases'])

    def test_query_doc_score_custom_similarity(self):
        """Searcher scores match ``compute_query_document_score`` for custom BM25 and QLD."""
        queries = ['information retrieval', 'databases']

        custom_bm25 = search.LuceneSimilarities.bm25(0.8, 0.2)
        self.searcher.set_bm25(0.8, 0.2)
        self._assert_scores_match(queries, similarity=custom_bm25)

        custom_qld = search.LuceneSimilarities.qld(500)
        self.searcher.set_qld(500)
        self._assert_scores_match(queries, similarity=custom_qld)

    def test_index_stats(self):
        """Check the document and unique-term counts reported by ``stats()``."""
        self.assertEqual(3204, self.index_reader.stats()['documents'])
        self.assertEqual(14363, self.index_reader.stats()['unique_terms'])

    def test_jstring_encoding(self):
        """Creating a ``JString`` with a non-ASCII character must not raise."""
        # When using pyjnius in a version prior to 1.3.0, creating a JString with non-ASCII characters resulted in a
        # failure. This test simply ensures that a compatible version of pyjnius is used. More details can be found in
        # the discussion here: https://github.com/castorini/pyserini/issues/770
        JString('zoölogy')

    def test_dump_documents_BM25(self):
        """Brute-force search over the BM25 dump gives the same top hits and scores as the searcher."""
        file_path = 'collections/cacm_documents_bm25_dump.jsonl'
        self._ensure_parent_dir(file_path)
        try:
            self.index_reader.dump_documents_BM25(file_path)
            docs = self._load_jsonl(file_path)
        finally:
            # Remove the dump even when dumping or parsing fails.
            self._remove_if_exists(file_path)

        self.assertEqual(len(docs), self.index_reader.stats()['documents'])

        for query in self._DUMP_QUERIES:
            # Search through documents BM25 dump
            ranking = self._brute_force_top(docs, query)
            self.assertEqual(len(ranking), self._TOP_K)

            # Using LuceneSearcher instead
            hits = self.searcher.search(query)
            self.assertGreaterEqual(len(hits), self._TOP_K)

            # The docids and the scores should match, rank by rank.
            for rank, (neg_score, docid) in enumerate(ranking):
                self.assertEqual(hits[rank].docid, docid)
                self.assertAlmostEqual(hits[rank].score, -neg_score, places=3)

    def test_quantize_weights(self):
        """Brute-force search over quantized weights roughly reproduces the searcher's ranking.

        If the weights are quantized the scores will not match, but the rankings
        should still roughly match: a document ranked ``i`` by the brute-force
        search must appear within ``tolerance`` places of ``i`` in the searcher's
        hit list.
        """
        dump_file_path = 'collections/cacm_documents_bm25_dump.jsonl'
        quantized_file_path = 'collections/cacm_documents_bm25_dump_quantized.jsonl'
        tolerance = 1

        self._ensure_parent_dir(dump_file_path)
        try:
            self.index_reader.dump_documents_BM25(dump_file_path)
            self.index_reader.quantize_weights(dump_file_path, quantized_file_path)
            docs = self._load_jsonl(quantized_file_path)
        finally:
            # Both the intermediate dump and the quantized file are temporary.
            self._remove_if_exists(dump_file_path)
            self._remove_if_exists(quantized_file_path)

        self.assertEqual(len(docs), self.index_reader.stats()['documents'])

        for query in self._DUMP_QUERIES:
            ranking = self._brute_force_top(docs, query)
            self.assertEqual(len(ranking), self._TOP_K)

            hits = self.searcher.search(query)

            for rank, (_, docid) in enumerate(ranking):
                # Only look at positions that exist in the hit list, so that a miss is
                # reported as a failed assertion rather than an IndexError.
                window = range(max(rank - tolerance, 0), min(rank + tolerance + 1, len(hits)))
                match_within_tolerance = any(hits[pos].docid == docid for pos in window)
                self.assertTrue(
                    match_within_tolerance,
                    '{} (brute-force rank {}) not within {} place(s) in searcher hits'.format(
                        docid, rank, tolerance))

    def tearDown(self):
        """Remove the downloaded tarball, the unpacked index and any temporary index folders."""
        os.remove(self.tarball_name)
        shutil.rmtree(self.index_dir)
        for f in self.temp_folders:
            shutil.rmtree(f)


if __name__ == '__main__':
    unittest.main()