import pickle
from pprint import pprint

from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever import TfidfRetriever, EmbeddingRetriever
from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline
dutch_dataset_name = 'Partisan news 2019 (dutch)'
german_dataset_name = 'CDU election program 2021'
class ExportableInMemoryDocumentStore(InMemoryDocumentStore):
    """
    Wrapper around the InMemoryDocumentStore.

    When the application is deployed to Huggingface Spaces there is no GPU available,
    so the indexes (including pre-calculated embeddings) are pickled ahead of time
    and loaded into the InMemoryDocumentStore at startup.
    """

    def export(self, file_name='in_memory_store.pkl'):
        """Pickle the document store's internal indexes to disk."""
        with open(file_name, 'wb') as f:
            pickle.dump(self.indexes, f)

    def load_data(self, file_name='in_memory_store.pkl'):
        """Load previously pickled indexes back into the document store."""
        with open(file_name, 'rb') as f:
            self.indexes = pickle.load(f)
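
# A minimal sketch of the intended export/load round trip (the file name and the
# `docs` / `retriever` variables below are illustrative, not part of this app):
#
#   # On a machine with a GPU: index the documents and pre-calculate embeddings.
#   store = ExportableInMemoryDocumentStore(similarity='cosine')
#   store.write_documents(docs)
#   store.update_embeddings(retriever)
#   store.export('my-index.pkl')
#
#   # On Huggingface Spaces (CPU only): load the pickled indexes instead.
#   store = ExportableInMemoryDocumentStore(similarity='cosine')
#   store.load_data('my-index.pkl')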
class SearchEngine:
    def __init__(self, document_store_name_base, document_store_name_adapted,
                 adapted_retriever_path):
        # Document store with embeddings from the base (off-the-shelf) retriever.
        self.document_store = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store.load_data(document_store_name_base)

        # Document store with embeddings from the domain-adapted retriever.
        self.document_store_adapted = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store_adapted.load_data(document_store_name_adapted)

        self.retriever = TfidfRetriever(document_store=self.document_store)
        self.base_dense_retriever = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
            model_format='sentence_transformers'
        )
        self.fine_tuned_retriever = EmbeddingRetriever(
            document_store=self.document_store_adapted,
            embedding_model=adapted_retriever_path,
            model_format='sentence_transformers'
        )

    def sparse_retrieval(self, query):
        """Sparse (TF-IDF) retrieval pipeline."""
        scores = self.retriever._calc_scores(query)
        p_retrieval = DocumentSearchPipeline(self.retriever)
        documents = p_retrieval.run(query=query)
        # The TfidfRetriever does not attach scores to the returned documents,
        # so the score of the top hit is filled in from the raw score matrix.
        documents['documents'][0].score = list(scores[0].values())[0]
        return documents

    def dense_retrieval(self, query, retriever='base'):
        """Dense retrieval with either the base or the domain-adapted retriever."""
        if retriever == 'base':
            p_retrieval = DocumentSearchPipeline(self.base_dense_retriever)
            return p_retrieval.run(query=query)
        if retriever == 'adapted':
            p_retrieval = DocumentSearchPipeline(self.fine_tuned_retriever)
            return p_retrieval.run(query=query)
        raise ValueError(f"Unknown retriever: {retriever!r}")

    def do_search(self, query):
        """Return the top hit from each of the three retrieval strategies."""
        sparse_result = self.sparse_retrieval(query)['documents'][0]
        dense_base_result = self.dense_retrieval(query, 'base')['documents'][0]
        dense_adapted_result = self.dense_retrieval(query, 'adapted')['documents'][0]
        return sparse_result, dense_base_result, dense_adapted_result
dutch_search_engine = SearchEngine('dutch-article-idx.pkl', 'dutch-article-idx_adapted.pkl',
                                   'dutch-article-retriever')
german_search_engine = SearchEngine('documentstore_german-election-idx.pkl',
                                    'documentstore_german-election-idx_adapted.pkl',
                                    'adapted-retriever')
def do_search(query, dataset):
    """Dispatch the query to the search engine matching the selected dataset."""
    if dataset == german_dataset_name:
        return german_search_engine.do_search(query)
    return dutch_search_engine.do_search(query)
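
# A hypothetical sketch of how `do_search` could be exposed as the Spaces demo UI
# (Gradio is assumed here; it is not imported or configured in this file):
#
#   import gradio as gr
#
#   demo = gr.Interface(
#       fn=lambda query, dataset: [str(doc) for doc in do_search(query, dataset)],
#       inputs=[gr.Textbox(label='Query'),
#               gr.Radio([dutch_dataset_name, german_dataset_name], label='Dataset')],
#       outputs=[gr.Textbox(label='Sparse (TF-IDF)'),
#                gr.Textbox(label='Dense (base)'),
#                gr.Textbox(label='Dense (adapted)')],
#   )
#   demo.launch()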
if __name__ == '__main__':
    search_engine = SearchEngine('dutch-article-idx.pkl', 'dutch-article-idx_adapted.pkl',
                                 'dutch-article-retriever')
    query = 'Kindergarten'
    result = search_engine.do_search(query)
    pprint(result)