Spaces:
Runtime error
Runtime error
File size: 2,465 Bytes
10641ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from haystack.document_stores import InMemoryDocumentStore
from haystack.utils import convert_files_to_docs
from haystack.nodes.retriever import TfidfRetriever
from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.nodes import FARMReader
import pickle
from pprint import pprint
class ExportableInMemoryDocumentStore(InMemoryDocumentStore):
    """
    InMemoryDocumentStore whose ``indexes`` can be saved to and restored from disk.

    When the application is deployed to Huggingface Spaces no GPU is
    available, so pre-calculated data has to be loaded into the store
    instead of being computed at start-up.
    """

    def export(self, file_name='in_memory_store.pkl'):
        """Pickle ``self.indexes`` into the file at *file_name* (overwriting it)."""
        with open(file_name, 'wb') as sink:
            pickle.dump(self.indexes, sink)

    def load_data(self, file_name='in_memory_store.pkl'):
        """Replace ``self.indexes`` with the object unpickled from *file_name*."""
        # NOTE: unpickling can execute arbitrary code; only load files that
        # were produced by `export` from a trusted source.
        with open(file_name, 'rb') as source:
            restored = pickle.load(source)
        self.indexes = restored
# Populate the store from the pickled data first, so that every retriever
# created below is built on top of an already-filled document store.
document_store = ExportableInMemoryDocumentStore(similarity='cosine')
document_store.load_data('documentstore_german-election-idx.pkl')

# Sparse retriever (TF-IDF).
retriever = TfidfRetriever(document_store=document_store)

# Two dense retrievers that differ only in the embedding model they use:
# first the multilingual sentence-transformers model, then the model stored
# under ./adapted-retriever. They are constructed in exactly this order.
base_dense_retriever, fine_tuned_retriever = (
    EmbeddingRetriever(
        document_store=document_store,
        embedding_model=model_ref,
        model_format='sentence_transformers',
    )
    for model_ref in (
        'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
        './adapted-retriever',
    )
)
def sparse_retrieval(query):
    """Run *query* through a document-search pipeline built on the TF-IDF retriever.

    A new ``DocumentSearchPipeline`` wrapping the module-level ``retriever``
    is created on every call, and the value of its ``run(query=query)`` call
    is returned unchanged.
    """
    return DocumentSearchPipeline(retriever).run(query=query)
def dense_retrieval(query, retriever='base'):
    """Run *query* through a document-search pipeline built on a dense retriever.

    Args:
        query: The search query passed on to the pipeline.
        retriever: ``'base'`` selects ``base_dense_retriever``; ``'adapted'``
            selects ``fine_tuned_retriever``.

    Returns:
        The value of ``DocumentSearchPipeline.run(query=query)`` for the
        selected retriever, or ``None`` when *retriever* is neither
        ``'base'`` nor ``'adapted'``.
    """
    # Guard clause: unknown retriever names yield None without touching
    # any pipeline or model.
    if retriever not in ('base', 'adapted'):
        return None

    selected = base_dense_retriever if retriever == 'base' else fine_tuned_retriever
    return DocumentSearchPipeline(selected).run(query=query)
def do_search(query):
    """Return the top hit's content from the sparse, base-dense and adapted-dense search.

    Each retrieval is run in that order for *query*, and the ``content`` of
    the first entry under the result's ``'documents'`` key is taken.

    Returns:
        A 3-tuple ``(sparse, dense_base, dense_adapted)`` of document contents.

    Note:
        Every result is indexed with ``[0]``, so each retrieval is expected
        to return at least one document.
    """
    def first_content(result):
        # Content of the first document in a pipeline result.
        return result['documents'][0].content

    return (
        first_content(sparse_retrieval(query)),
        first_content(dense_retrieval(query, retriever='base')),
        first_content(dense_retrieval(query, retriever='adapted')),
    )
if __name__ == '__main__':
    # Manual smoke test: run one fixed German query through all three
    # retrievers and pretty-print the resulting tuple of top hits.
    pprint(do_search('Klimawandel stoppen?'))
|