#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import multiprocessing
import string
from multiprocessing.pool import ThreadPool

import jnius
from nltk import SnowballStemmer, bigrams, word_tokenize
from nltk.corpus import stopwords
from pyserini.search.lucene import LuceneSearcher
from tqdm import tqdm

import kilt.kilt_utils as utils
from kilt.retrievers.base_retriever import Retriever

# Entity markers used by some KILT tasks; they are stripped from queries before search.
ent_start_token = "[START_ENT]"
ent_end_token = "[END_ENT]"

# Requires the NLTK "stopwords" and "punkt" data packages (nltk.download(...)).
STOPWORDS = set(stopwords.words("english") + list(string.punctuation))
stemmer = SnowballStemmer("english")


def parse_hits(hits):
    """Collapse passage-level hits into unique Wikipedia page ids, keeping the
    score of the first (highest-ranked) passage seen for each page."""
    doc_ids = []
    doc_scores = []
    for hit in hits:
        # docids follow the "<wikipedia_id>-<paragraph_id>" convention.
        wikipedia_id = hit.docid.split("-")[0]
        if wikipedia_id and wikipedia_id not in doc_ids:
            doc_ids.append(wikipedia_id)
            doc_scores.append(hit.score)
    return doc_ids, doc_scores
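

# A minimal sketch of what parse_hits does, assuming the "<wikipedia_id>-<paragraph_id>"
# docid convention used above; FakeHit is a hypothetical stand-in for a Lucene hit:
#
#   from collections import namedtuple
#   FakeHit = namedtuple("FakeHit", ["docid", "score"])
#   hits = [FakeHit("12-3", 9.1), FakeHit("12-7", 8.4), FakeHit("34-1", 7.9)]
#   parse_hits(hits)  # -> (["12", "34"], [9.1, 7.9])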


def _get_predictions_thread(arguments):
    """Run BM25 retrieval for one chunk of queries on a single thread."""
    thread_id = arguments["id"]
    queries_data = arguments["queries_data"]
    topk = arguments["topk"]
    ranker = arguments["ranker"]
    logger = arguments["logger"]
    use_bigrams = arguments["use_bigrams"]
    stem_bigrams = arguments["stem_bigrams"]

    # Only the first thread reports progress.
    if thread_id == 0:
        iter_ = tqdm(queries_data)
    else:
        iter_ = queries_data

    result_doc_ids = []
    result_doc_scores = []
    result_query_id = []

    for query_element in iter_:
        # Strip the KILT entity markers before searching.
        query = (
            query_element["query"]
            .replace(ent_start_token, "")
            .replace(ent_end_token, "")
            .strip()
        )
        result_query_id.append(query_element["id"])

        doc_ids = []
        doc_scores = []

        if use_bigrams:
            # Append concatenated (optionally stemmed) bigram tokens to the query.
            tokens = filter(
                lambda word: word.lower() not in STOPWORDS, word_tokenize(query)
            )
            if stem_bigrams:
                tokens = map(stemmer.stem, tokens)
            bigram_query = " ".join("".join(bigram) for bigram in bigrams(tokens))
            query += " " + bigram_query

        try:
            hits = ranker.search(query, k=topk)
            doc_ids, doc_scores = parse_hits(hits)
        except RuntimeError as e:
            if logger:
                logger.warning("RuntimeError: {}".format(e))
        except jnius.JavaException as e:
            if logger:
                logger.warning("{} jnius.JavaException: {}".format(query_element, e))
            if "maxClauseCount" in str(e):
                # Lucene caps the number of boolean clauses (1024 by default);
                # truncate the query and retry.
                query = " ".join(query.split()[:950])
                hits = ranker.search(query, k=topk)
                doc_ids, doc_scores = parse_hits(hits)
            else:
                print(query, str(e))
                raise e
        except Exception as e:
            print(query, str(e))
            raise e

        result_doc_ids.append(doc_ids)
        result_doc_scores.append(doc_scores)

    return result_doc_ids, result_doc_scores, result_query_id
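

# For illustration only (not part of the original module): with use_bigrams=True and
# stem_bigrams=False, the expansion above appends concatenated bigram tokens, e.g.
#
#   tokens = [t for t in word_tokenize("New York City") if t.lower() not in STOPWORDS]
#   " ".join("".join(bg) for bg in bigrams(tokens))  # -> "NewYork YorkCity"
#
# so the query sent to BM25 becomes "New York City NewYork YorkCity".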


class Anserini(Retriever):
    def __init__(self, name, num_threads, index_dir=None, k1=0.9, b=0.4,
                 use_bigrams=False, stem_bigrams=False):
        super().__init__(name)
        self.num_threads = min(num_threads, multiprocessing.cpu_count())

        # Initialize one BM25 searcher per thread, all reading the same Lucene index.
        self.arguments = []
        for thread_id in tqdm(range(self.num_threads)):
            ranker = LuceneSearcher(index_dir)
            ranker.set_bm25(k1, b)
            self.arguments.append(
                {
                    "id": thread_id,
                    "ranker": ranker,
                    "use_bigrams": use_bigrams,
                    "stem_bigrams": stem_bigrams,
                }
            )

    def fed_data(self, queries_data, topk, logger=None):
        # Split the queries into one chunk per thread.
        chunked_queries = utils.chunk_it(queries_data, self.num_threads)
        for idx, arg in enumerate(self.arguments):
            arg["queries_data"] = chunked_queries[idx]
            arg["topk"] = topk
            arg["logger"] = logger

    def run(self):
        pool = ThreadPool(self.num_threads)
        results = pool.map(_get_predictions_thread, self.arguments)

        all_doc_id = []
        all_doc_scores = []
        all_query_id = []
        provenance = {}

        for doc_ids_chunk, doc_scores_chunk, query_ids_chunk in results:
            all_doc_id.extend(doc_ids_chunk)
            all_doc_scores.extend(doc_scores_chunk)
            all_query_id.extend(query_ids_chunk)
            # Record provenance in the KILT format: one wikipedia_id per retrieved page.
            for query_id, doc_ids in zip(query_ids_chunk, doc_ids_chunk):
                provenance[query_id] = [
                    {"wikipedia_id": str(d_id).strip()} for d_id in doc_ids
                ]

        pool.terminate()
        pool.join()

        return all_doc_id, all_doc_scores, all_query_id, provenance
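

# A minimal usage sketch, assuming a prebuilt KILT Lucene index on disk; the index
# path and the toy queries below are placeholders, not part of the original file.
if __name__ == "__main__":
    retriever = Anserini(
        name="anserini_bm25",
        num_threads=2,
        index_dir="/path/to/kilt/lucene_index",  # hypothetical index location
        k1=0.9,
        b=0.4,
    )
    # KILT query elements carry an "id" and a natural-language "query".
    retriever.fed_data(
        queries_data=[
            {"id": "0", "query": "Who wrote The Old Man and the Sea?"},
            {"id": "1", "query": "Where is the Eiffel Tower located?"},
        ],
        topk=10,
    )
    all_doc_id, all_doc_scores, all_query_id, provenance = retriever.run()
    print(provenance)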