Spaces:
Runtime error
Runtime error
File size: 4,322 Bytes
b951bdb 7eec8cd b951bdb a0c9518 b951bdb 7eec8cd b951bdb 7eec8cd b951bdb 0853141 b951bdb 0853141 b951bdb 0853141 b951bdb 0853141 b951bdb 0853141 b951bdb 0853141 b951bdb a8e52be 7eec8cd b951bdb 2e2bd12 8362484 b951bdb 8362484 b951bdb 8362484 dfa96a4 8362484 d0d3b14 8362484 b951bdb cbd24a5 8362484 0853141 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import streamlit as st
import json
import time
import faiss
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
class DocumentSearch:
    """Semantic document search over a prebuilt FAISS index.

    Artifacts loaded by ``__init__``:
      * ``docs_path`` -- JSON file with the documents. Each entry is read as
        ``entry[0]`` = document text and ``entry[1]`` = source URL, looked up
        by the integer ids the FAISS index returns (presumably a list aligned
        with the index ids -- TODO confirm against the index-building script).
      * ``idx_path`` -- serialized FAISS index of document embeddings.
      * ``enc_path`` -- SentenceTransformer (bi-encoder) that embeds queries.

    ``cross_enc_path`` names a CrossEncoder re-ranker that is currently NOT
    loaded; results are ranked by the FAISS score alone.
    """

    # Paths of every file/model needed to run the models and search the data.
    enc_path = "ivan-savchuk/msmarco-distilbert-dot-v5-tuned-full-v1"
    idx_path = "idx_vectors.index"
    cross_enc_path = "ivan-savchuk/cross-encoder-ms-marco-MiniLM-L-12-v2-tuned_mediqa-v1"
    docs_path = "docs.json"

    def __init__(self):
        # Documents together with their source URLs.
        with open(DocumentSearch.docs_path, "r", encoding="utf-8") as json_file:
            self.docs = json.load(json_file)
        # Bi-encoder that maps a query into the index's vector space.
        self.encoder = SentenceTransformer(DocumentSearch.enc_path)
        # Prebuilt FAISS index over the document embeddings.
        self.index = faiss.read_index(DocumentSearch.idx_path)

    def search(self, query: str, k: int) -> list:
        """Return up to ``k`` documents most similar to ``query``.

        Args:
            query: free-text search query.
            k: maximum number of results; ``k <= 0`` returns ``[]``.

        Returns:
            A list of ``{'doc': ..., 'url': ..., 'score': float}`` dicts in the
            order FAISS ranks them (best match first). Fewer than ``k`` entries
            are returned when the index holds fewer than ``k`` vectors.
        """
        if k <= 0:
            return []
        # encode() works on a batch, so the single query is wrapped in a list.
        query_vector = self.encoder.encode([query])
        # No re-ranking stage is active, so fetch exactly k neighbours
        # (the old k*10 over-fetch only made sense with the cross-encoder).
        distances, indices = self.index.search(query_vector, k)

        results = []
        # Row 0 holds the hits for our single query.
        for idx, score in zip(indices[0], distances[0]):
            # FAISS pads with -1 when it cannot find k neighbours; indexing
            # self.docs with -1 would silently return the last document.
            if idx < 0:
                continue
            entry = self.docs[int(idx)]
            results.append({"doc": entry[0], "url": entry[1], "score": float(score)})
        return results
if __name__ == "__main__":
    # Streamlit re-executes this whole script on every widget interaction, so
    # an uncached DocumentSearch() would reload the encoder and the FAISS index
    # for every search. Keep one instance across reruns with Streamlit's
    # resource cache: `cache_resource` on newer releases, `experimental_singleton`
    # on older ones; without either, fall back to plain construction.
    _resource_cache = getattr(st, "cache_resource", None) or getattr(
        st, "experimental_singleton", None
    )

    def _load_searcher():
        """Build the DocumentSearch instance (expensive: loads models and index)."""
        return DocumentSearch()

    if _resource_cache is not None:
        _load_searcher = _resource_cache(_load_searcher)

    surfer = _load_searcher()

    # Page title. Built without leading indentation so Markdown cannot treat
    # indented lines as a code block.
    title = (
        "<h1 style='text-align: center; color: #3CB371'>"
        "Medical Search"
        "</h1>"
    )
    st.markdown(title, unsafe_allow_html=True)

    # Input form.
    with st.form("my_form"):
        query = st.text_input("Enter any query about our medical data",
                              placeholder="Type query here...")
        # Every form must have a submit button.
        submitted = st.form_submit_button("Search")

    if submitted and query.strip():
        # Time only the search itself; perf_counter is monotonic and meant
        # for measuring durations.
        start = time.perf_counter()
        results = surfer.search(query, k=5)
        elapsed_time = round(time.perf_counter() - start, 2)

        # Echo the query and the search time.
        st.write(f"**Results Related to:** \"{query}\" ({elapsed_time} sec.)")
        for i, answer in enumerate(results, start=1):
            st.subheader(f"Answer {i}")
            # Show a preview; mark it with "..." only when actually truncated.
            doc = answer["doc"]
            if len(doc) > 150:
                doc = doc[:150] + "..."
            # Link to the full answer.
            url = answer["url"]
            st.markdown(f'{doc}\n[**Read More**]({url})\n', unsafe_allow_html=True)

        st.markdown("---")
        st.markdown("**Author:** Ivan Savchuk. 2022")
    else:
        if submitted:
            # Submitting an empty/whitespace query would search for nothing.
            st.warning("Please enter a query.")
        st.markdown(
            "Typical queries look like this: "
            "_**\"What is flu?\"**_, "
            "_**\"How to cure breast cancer?\"**_, "
            "_**\"I have headache, what should I do?\"**_"
        )
|