# semantic-search-demo / retriever.py
# Last commit by mrchtr (8bd9363): "Add dutch partisan news dataset"
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever import TfidfRetriever
from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline
from haystack.nodes.retriever import EmbeddingRetriever
import pickle
from pprint import pprint
# Names of the two demo datasets. `do_search` below compares its `dataset`
# argument against `german_datset_name` to choose an engine. The names are
# presumably also what the UI offers — verify against the app. The "datset"
# spelling is kept so existing imports keep working.
dutch_datset_name, german_datset_name = (
    'Partisan news 2019 (dutch)',
    'CDU election program 2021',
)
class ExportableInMemoryDocumentStore(InMemoryDocumentStore):
    """
    InMemoryDocumentStore whose ``indexes`` can be saved to and restored from a pickle file.

    When the application is deployed to Huggingface Spaces there is no GPU
    available. Pre-calculated data (presumably the document embeddings —
    verify) is therefore stored with :meth:`export` and loaded back into the
    store with :meth:`load_data` instead of being computed at startup.
    """

    def export(self, file_name='in_memory_store.pkl'):
        """Pickle ``self.indexes`` into *file_name*, overwriting the file if it exists."""
        with open(file_name, mode='wb') as sink:
            pickle.dump(self.indexes, sink)

    def load_data(self, file_name='in_memory_store.pkl'):
        """
        Replace ``self.indexes`` with the object unpickled from *file_name*.

        Security: unpickling can execute arbitrary code. Only load files you
        trust (for example, ones written by :meth:`export`).
        """
        with open(file_name, mode='rb') as source:
            restored = pickle.load(source)
        self.indexes = restored
class SearchEngine:
    """
    Query one dataset with three retrievers side by side.

    * ``retriever``: a sparse TF-IDF retriever over the base document store.
    * ``base_dense_retriever``: a dense retriever using the off-the-shelf
      ``sentence-transformers/paraphrase-multilingual-mpnet-base-v2`` model
      over the base document store.
    * ``fine_tuned_retriever``: a dense retriever using a fine-tuned
      ("adapted") sentence-transformers model over a separate adapted store.

    Both document stores use cosine similarity. Their contents are loaded with
    ``ExportableInMemoryDocumentStore.load_data``.
    """

    def __init__(self, document_store_name_base, document_store_name_adpated,
                 adapted_retriever_path):
        """
        Load both document stores and build the retrievers and their pipelines.

        Args:
            document_store_name_base: Pickle file with the base store's indexes.
            document_store_name_adpated: Pickle file with the adapted store's
                indexes. The misspelled parameter name is kept for keyword callers.
            adapted_retriever_path: Model path/name of the fine-tuned
                sentence-transformers model, passed to ``EmbeddingRetriever``.
        """
        self.document_store = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store.load_data(document_store_name_base)
        self.document_store_adapted = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store_adapted.load_data(document_store_name_adpated)

        # Retrievers are created after the stores are loaded, as in the original order.
        self.retriever = TfidfRetriever(document_store=self.document_store)
        self.base_dense_retriever = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
            model_format='sentence_transformers',
        )
        self.fine_tuned_retriever = EmbeddingRetriever(
            document_store=self.document_store_adapted,
            embedding_model=adapted_retriever_path,
            model_format='sentence_transformers',
        )

        # Build each search pipeline once rather than on every query.
        self._sparse_pipeline = DocumentSearchPipeline(self.retriever)
        self._base_dense_pipeline = DocumentSearchPipeline(self.base_dense_retriever)
        self._adapted_dense_pipeline = DocumentSearchPipeline(self.fine_tuned_retriever)

    def sparse_retrieval(self, query):
        """
        Retrieve documents for *query* with the TF-IDF pipeline.

        If the pipeline returns at least one document and a score is
        available, the first document's ``score`` is set to the first value
        of the TF-IDF score mapping for this query.

        Returns:
            The ``DocumentSearchPipeline.run`` output dict. Its ``'documents'``
            list may be empty.
        """
        # NOTE(review): `_calc_scores` is a private TfidfRetriever API and may
        # change between Haystack versions. It scores the query a second time,
        # separately from the pipeline run, and is used only to attach a score.
        scores = self.retriever._calc_scores(query)
        documents = self._sparse_pipeline.run(query=query)
        found = documents['documents']
        # Assumes the first entry of scores[0] belongs to the top-ranked
        # document (i.e. the mapping is sorted best-first) — confirm for the
        # installed Haystack version.
        if found and scores and scores[0]:
            found[0].score = next(iter(scores[0].values()))
        return documents

    def dense_retrieval(self, query, retriever='base'):
        """
        Retrieve documents for *query* with one of the dense retrievers.

        Args:
            query: The search query.
            retriever: ``'base'`` for the off-the-shelf model over the base
                store, or ``'adapted'`` for the fine-tuned model over the
                adapted store.

        Returns:
            The ``DocumentSearchPipeline.run`` output dict.

        Raises:
            ValueError: If *retriever* is neither ``'base'`` nor ``'adapted'``.
        """
        if retriever == 'base':
            pipeline = self._base_dense_pipeline
        elif retriever == 'adapted':
            pipeline = self._adapted_dense_pipeline
        else:
            raise ValueError(
                f"Unknown retriever {retriever!r}; expected 'base' or 'adapted'."
            )
        return pipeline.run(query=query)

    def do_search(self, query):
        """
        Return the first document from each retriever for *query*.

        Returns:
            A tuple ``(sparse, dense_base, dense_adapted)``.

        Raises:
            IndexError: If any of the three retrievers returns no documents.
        """
        sparse_result = self.sparse_retrieval(query)['documents'][0]
        dense_base_result = self.dense_retrieval(query, 'base')['documents'][0]
        dense_adapted_result = self.dense_retrieval(query, 'adapted')['documents'][0]
        return sparse_result, dense_base_result, dense_adapted_result
# One engine per dataset, built at import time. Each loads a base and an
# adapted document store from the given pickle files, plus the fine-tuned
# sentence-transformers model named by `adapted_retriever_path`.
dutch_search_engine = SearchEngine(
    document_store_name_base='dutch-article-idx.pkl',
    document_store_name_adpated='dutch-article-idx_adapted.pkl',
    adapted_retriever_path='dutch-article-retriever',
)
german_search_engine = SearchEngine(
    document_store_name_base='documentstore_german-election-idx.pkl',
    document_store_name_adpated='documentstore_german-election-idx_adapted.pkl',
    adapted_retriever_path='adapted-retriever',
)
def do_search(query, dataset):
    """
    Search the selected dataset and return the first hit from each retriever.

    ``german_datset_name`` selects the German engine. Any other value,
    including ``dutch_datset_name``, falls back to the Dutch engine.
    """
    engine = german_search_engine if dataset == german_datset_name else dutch_search_engine
    return engine.do_search(query)
if __name__ == '__main__':
    # Manual smoke test. Reuse the Dutch engine already built at import time
    # rather than constructing a second SearchEngine, which would reload both
    # pickled document stores and both embedding models.
    query = 'Kindergarten'
    result = dutch_search_engine.do_search(query)
    pprint(result)