"""Multilingual document search demo built on Haystack.

Provides sparse (TF-IDF) and dense (sentence-transformer) retrieval over two
pre-indexed corpora (Dutch partisan news, German CDU election program).
Designed for Huggingface Spaces where no GPU is available, so document
embeddings are pre-computed offline and loaded from pickled index files.
"""

from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever import TfidfRetriever
from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline
from haystack.nodes.retriever import EmbeddingRetriever
import pickle
from pprint import pprint

# NOTE(review): "datset" is a typo for "dataset"; the names are kept because
# they are module-level constants that external code (e.g. a UI layer) may
# import and compare against.
dutch_datset_name = 'Partisan news 2019 (dutch)'
german_datset_name = 'CDU election program 2021'


class ExportableInMemoryDocumentStore(InMemoryDocumentStore):
    """
    Wrapper class around the InMemoryDocumentStore.

    When the application is deployed to Huggingface Spaces there will be no
    GPU available, so embeddings are pre-calculated offline, exported with
    :meth:`export`, and re-loaded at startup with :meth:`load_data`.
    """

    def export(self, file_name='in_memory_store.pkl'):
        """Serialize this store's internal indexes to ``file_name`` via pickle.

        :param file_name: Path of the pickle file to write.
        """
        with open(file_name, 'wb') as f:
            pickle.dump(self.indexes, f)

    def load_data(self, file_name='in_memory_store.pkl'):
        """Restore previously exported indexes from ``file_name``.

        SECURITY NOTE: ``pickle.load`` can execute arbitrary code embedded in
        the file. Only load index files produced by this application itself —
        never untrusted input.

        :param file_name: Path of the pickle file to read.
        """
        with open(file_name, 'rb') as f:
            self.indexes = pickle.load(f)


class SearchEngine:
    """Bundle of one sparse and two dense retrieval pipelines over one corpus.

    Holds two document stores: one whose embeddings were produced by the base
    multilingual model, and one whose embeddings were produced by a fine-tuned
    ("adapted") model, so that results of both can be compared side by side.
    """

    def __init__(self, document_store_name_base, document_store_name_adpated,
                 adapted_retriever_path):
        """
        :param document_store_name_base: Pickle file with the base-model index.
        :param document_store_name_adpated: Pickle file with the adapted-model
            index. (Parameter name keeps the original "adpated" typo so that
            existing keyword-argument callers do not break.)
        :param adapted_retriever_path: Path/name of the fine-tuned
            sentence-transformers model.
        """
        self.document_store = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store.load_data(document_store_name_base)
        self.document_store_adapted = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store_adapted.load_data(document_store_name_adpated)

        # Sparse baseline: term-frequency retriever over the base store.
        self.retriever = TfidfRetriever(document_store=self.document_store)
        # Dense baseline: off-the-shelf multilingual sentence-transformer.
        self.base_dense_retriever = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
            model_format='sentence_transformers'
        )
        # Dense fine-tuned: domain-adapted model over its matching store.
        self.fine_tuned_retriever = EmbeddingRetriever(
            document_store=self.document_store_adapted,
            embedding_model=adapted_retriever_path,
            model_format='sentence_transformers'
        )

    def sparse_retrieval(self, query):
        """Run the TF-IDF pipeline and attach a raw score to the top hit.

        NOTE(review): ``_calc_scores`` is a private Haystack API; it is used
        here because the public pipeline output does not expose raw TF-IDF
        scores. This may break on Haystack upgrades — verify when bumping the
        dependency.

        :param query: Free-text search query.
        :return: Pipeline result dict; ``result['documents']`` holds the hits.
        """
        scores = self.retriever._calc_scores(query)
        p_retrieval = DocumentSearchPipeline(self.retriever)
        documents = p_retrieval.run(query=query)
        # Guard the empty-corpus / no-hit case instead of raising IndexError.
        if documents['documents'] and scores and scores[0]:
            documents['documents'][0].score = next(iter(scores[0].values()))
        return documents

    def dense_retrieval(self, query, retriever='base'):
        """Run a dense retrieval pipeline.

        :param query: Free-text search query.
        :param retriever: ``'base'`` for the off-the-shelf model or
            ``'adapted'`` for the fine-tuned one.
        :return: Pipeline result dict.
        :raises ValueError: If ``retriever`` is neither ``'base'`` nor
            ``'adapted'`` (the original code silently returned ``None`` here).
        """
        if retriever == 'base':
            return DocumentSearchPipeline(self.base_dense_retriever).run(query=query)
        if retriever == 'adapted':
            return DocumentSearchPipeline(self.fine_tuned_retriever).run(query=query)
        raise ValueError(
            f"Unknown retriever {retriever!r}; expected 'base' or 'adapted'."
        )

    def do_search(self, query):
        """Run all three retrievers and return each one's top document.

        :param query: Free-text search query.
        :return: Tuple ``(sparse_top, dense_base_top, dense_adapted_top)``.
        """
        sparse_result = self.sparse_retrieval(query)['documents'][0]
        dense_base_result = self.dense_retrieval(query, 'base')['documents'][0]
        dense_adapted_result = self.dense_retrieval(query, 'adapted')['documents'][0]
        return sparse_result, dense_base_result, dense_adapted_result


# NOTE(review): constructing these at import time loads pickle files and
# instantiates embedding models — a heavy import-time side effect, kept
# because downstream code imports these module-level instances directly.
dutch_search_engine = SearchEngine('dutch-article-idx.pkl',
                                   'dutch-article-idx_adapted.pkl',
                                   'dutch-article-retriever')
german_search_engine = SearchEngine('documentstore_german-election-idx.pkl',
                                    'documentstore_german-election-idx_adapted.pkl',
                                    'adapted-retriever')


def do_search(query, dataset):
    """Dispatch a search to the engine matching ``dataset``.

    :param query: Free-text search query.
    :param dataset: Dataset display name; anything other than the German
        dataset name falls back to the Dutch engine (original behavior).
    :return: Tuple of top documents, see :meth:`SearchEngine.do_search`.
    """
    if dataset == german_datset_name:
        return german_search_engine.do_search(query)
    return dutch_search_engine.do_search(query)


if __name__ == '__main__':
    search_engine = SearchEngine('dutch-article-idx.pkl',
                                 'dutch-article-idx_adapted.pkl',
                                 'dutch-article-retriever')
    query = 'Kindergarten'
    result = search_engine.do_search(query)
    pprint(result)