File size: 2,465 Bytes
10641ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from haystack.document_stores import InMemoryDocumentStore
from haystack.utils import convert_files_to_docs
from haystack.nodes.retriever import TfidfRetriever
from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.nodes import FARMReader
import pickle
from pprint import pprint

class ExportableInMemoryDocumentStore(InMemoryDocumentStore):
    """
    InMemoryDocumentStore whose ``indexes`` can be written to and restored from disk.

    The Huggingface Spaces deployment has no GPU available, so the data is
    calculated ahead of time and loaded into the store from a pickle file.
    """

    def export(self, file_name='in_memory_store.pkl'):
        """Pickle ``self.indexes`` into ``file_name``, overwriting any existing file."""
        with open(file_name, mode='wb') as pickle_sink:
            pickle.dump(self.indexes, pickle_sink)

    def load_data(self, file_name='in_memory_store.pkl'):
        """Replace ``self.indexes`` with the object unpickled from ``file_name``.

        Unpickling can execute arbitrary code, so only load files you trust.
        """
        with open(file_name, mode='rb') as pickle_source:
            self.indexes = pickle.load(pickle_source)



# Shared document store. Its indexes are restored from a pickle file rather
# than being built at start-up.
document_store = ExportableInMemoryDocumentStore(similarity='cosine')
document_store.load_data('documentstore_german-election-idx.pkl')

# Sparse (TF-IDF) retriever over the shared store.
# NOTE(review): it is created after load_data() — presumably it needs the
# documents to be present at construction time; confirm before reordering.
retriever = TfidfRetriever(document_store=document_store)
# Dense retriever using the multilingual sentence-transformers model named below.
base_dense_retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
        model_format='sentence_transformers'
    )

# Dense retriever whose model is read from the local './adapted-retriever'
# path — presumably the fine-tuned (domain-adapted) variant; confirm.
fine_tuned_retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model='./adapted-retriever',
        model_format='sentence_transformers'
)

def sparse_retrieval(query):
    """Run ``query`` through a document-search pipeline built on the sparse retriever.

    A new ``DocumentSearchPipeline`` wrapping the module-level ``retriever`` is
    created for each call, and the value returned by its ``run`` is passed back
    unchanged.
    """
    return DocumentSearchPipeline(retriever).run(query=query)

def dense_retrieval(query, retriever='base'):
    """Run ``query`` through a document-search pipeline built on a dense retriever.

    ``retriever`` selects the embedding retriever by name: ``'base'`` uses
    ``base_dense_retriever`` and ``'adapted'`` uses ``fine_tuned_retriever``.
    Any other value yields ``None`` without running a search.
    """
    # Resolve the retriever first so the pipeline is built in exactly one place.
    if retriever == 'base':
        selected = base_dense_retriever
    elif retriever == 'adapted':
        selected = fine_tuned_retriever
    else:
        return None

    return DocumentSearchPipeline(selected).run(query=query)


def _first_document_content(result):
    """Return the content of the first document in ``result``, or ``''`` if there is none."""
    documents = result['documents']
    # A retriever may come back with no documents at all; indexing [0]
    # unconditionally would raise IndexError in that case.
    return documents[0].content if documents else ''


def do_search(query):
    """Run ``query`` through the sparse, base dense and adapted dense retrievers.

    Args:
        query: The search query string.

    Returns:
        A 3-tuple ``(sparse_result, dense_base_result, dense_adapted_result)``
        holding the content of the first document returned by each retriever.
        A slot is the empty string when that retriever returned no documents.
    """
    sparse_result = _first_document_content(sparse_retrieval(query))
    dense_base_result = _first_document_content(dense_retrieval(query, retriever='base'))
    dense_adapted_result = _first_document_content(dense_retrieval(query, retriever='adapted'))
    return sparse_result, dense_base_result, dense_adapted_result

if __name__ == '__main__':
    # Manual smoke test: send one sample query through all three retrievers
    # and pretty-print the resulting tuple.
    pprint(do_search('Klimawandel stoppen?'))