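# Streamlit app for semantic search over medical documents, backed by a
# FAISS index and fine-tuned SBERT models. Assuming this file is saved as
# app.py, it can be launched locally with: streamlit run app.py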
import streamlit as st
import json
import time
import faiss
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder


class DocumentSearch:
    '''
        Performs semantic document search
        based on previously trained components:
        faiss: index
        sbert: encoder
        sbert: cross_encoder (currently disabled)
    '''
    # paths to every file needed to run the models
    # and to search over our data
    enc_path = "ivan-savchuk/msmarco-distilbert-dot-v5-tuned-full-v1"
    idx_path = "idx_vectors.index"
    cross_enc_path = "ivan-savchuk/cross-encoder-ms-marco-MiniLM-L-12-v2-tuned_mediqa-v1"
    docs_path = "docs.json"
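    # note: the two model paths above are Hugging Face Hub IDs, downloaded
    # on first use; the index and docs are assumed to be local files
    # sitting next to this script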
    
    def __init__(self):
        # loading docs and corresponding urls
        with open(DocumentSearch.docs_path, 'r') as json_file:
            self.docs = json.load(json_file)
        
        # loading sbert encoder model
        self.encoder = SentenceTransformer(DocumentSearch.enc_path)
        # loading faiss index
        self.index = faiss.read_index(DocumentSearch.idx_path)
        # loading sbert cross_encoder (re-ranking is disabled;
        # see the old version at the bottom of search())
        # self.cross_encoder = CrossEncoder(DocumentSearch.cross_enc_path)

    def search(self, query: str, k: int) -> list:
        # get vector representation of the text query
        query_vector = self.encoder.encode([query])
        # search the faiss FlatIP index; k*10 candidates are fetched
        # (headroom originally meant for cross-encoder re-ranking)
        distances, indices = self.index.search(query_vector, k * 10)
        # get docs by index
        res_docs = [self.docs[i] for i in indices[0]]
        # get similarity scores (inner products) by index
        dists = list(distances[0])

        # compose results and keep only the top k
        return [{'doc': doc[0], 'url': doc[1], 'score': dist} for doc, dist in zip(res_docs, dists)][:k]
        ##### OLD VERSION WITH CROSS-ENCODER #####
        # get answers by index
        # answers = [self.docs[i] for i in indices[0]]
        # prepare inputs for cross encoder
        # model_inputs = [[query, pairs[0]] for pairs in answers]
        # urls = [pairs[1] for pairs in answers]
        # get similarity score between query and documents
        # scores = self.cross_encoder.predict(model_inputs, batch_size=1)
        # compose results into list of dicts
        # results = [{'doc': doc[1], 'url': url, 'score': score} for doc, url, score in zip(model_inputs, urls, scores)]
        
        # return results sorted by similarity scores
        # return sorted(results, key=lambda x: x['score'], reverse=True)[:k]
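
# A minimal sketch (an assumption, not shipped with this app) of how the
# idx_vectors.index and docs.json files consumed above could be produced
# with the same encoder; docs.json is taken to hold [text, url] pairs:
#
#   import json
#   import faiss
#   from sentence_transformers import SentenceTransformer
#
#   encoder = SentenceTransformer(DocumentSearch.enc_path)
#   docs = [["Flu is a contagious respiratory illness...", "https://example.com/flu"]]
#   vectors = encoder.encode([text for text, url in docs])  # float32 matrix
#   index = faiss.IndexFlatIP(vectors.shape[1])             # dot-product index
#   index.add(vectors)
#   faiss.write_index(index, "idx_vectors.index")
#   with open("docs.json", "w") as f:
#       json.dump(docs, f)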


if __name__ == "__main__":
    # get instance of DocumentSearch class
    surfer = DocumentSearch()
    # streamlit part starts here with title
    title = """
    <h1 style='
    text-align: center;
    color: #3CB371'>
    Medical Search
    </h1>
    """
    st.markdown(title, unsafe_allow_html=True)
    # input form
    with st.form("my_form"):
        # here we have input space
        query = st.text_input("Enter query about our Medical Data",
                              placeholder="Type query here...",
                              max_chars=200)
        # Every form must have a submit button.
        submitted = st.form_submit_button("Search")
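        # widgets inside st.form are batched: their values reach the
        # script only when the submit button is pressed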
    
    # on submit we execute search
    if submitted:
        # record the start time
        stt = time.time()
        # retrieve the top 10 documents
        results = surfer.search(query, k=10)
        # record the end time
        ent = time.time()
        # compute elapsed search time in seconds
        elapsed_time = round(ent - stt, 2)
        
        # show which query was entered, and what was searching time
        st.write(f"**Results Related to:** \"{query}\" ({elapsed_time} sec.)")
        # then we use loop to show results
        for i, answer in enumerate(results):
            # answer starts with header
            st.subheader(f"Answer {i+1}")
            # crop the answer text to 250 characters
            doc = answer["doc"][:250] + "..."
            # and url to the full answer
            url = answer["url"]
            # then we display it
            st.markdown(f'{doc}\n[**Read More**]({url})\n', unsafe_allow_html=True)

            
        st.markdown("---")
        st.markdown("**Author:** Ivan Savchuk. 2022")
    else:
        st.markdown("Typical queries looks like this: _**\"What is flu?\"**_,\
                    _**\"How to cure breast cancer?\"**_,\
                    _**\"I have headache, what should I do?\"**_")