Spaces:
Runtime error
Runtime error
File size: 4,372 Bytes
31c5069 bf7dfb8 31c5069 4793e50 bf7dfb8 31c5069 bf7dfb8 31c5069 bf7dfb8 31c5069 bf7dfb8 31c5069 bf7dfb8 31c5069 bf7dfb8 31c5069 bf7dfb8 31c5069 2613971 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import streamlit as st
import json
import time
import faiss
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
class DocumentSearch:
    '''
    Semantic document search over a pre-built corpus.

    Relies on previously trained artifacts:
        faiss: vector index (``idx_path``) whose ids are positions in ``docs.json``
        sbert: bi-encoder (``enc_path``) used to embed queries
        sbert: cross-encoder (``cross_enc_path``), currently not loaded

    Each entry of ``docs.json`` is read as a ``[text, url]`` pair.
    '''
    # Paths to every artifact needed to run the models and search the data.
    # The model paths are presumably Hugging Face Hub ids (anything accepted
    # by sentence-transformers as a model name or local directory works).
    enc_path = "ivan-savchuk/msmarco-distilbert-dot-v5-tuned-full-v1"
    idx_path = "idx_vectors.index"
    cross_enc_path = "ivan-savchuk/cross-encoder-ms-marco-MiniLM-L-12-v2-tuned_mediqa-v1"
    docs_path = "docs.json"

    def __init__(self):
        '''Load the documents, the query encoder and the faiss index.'''
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(DocumentSearch.docs_path, 'r', encoding='utf-8') as json_file:
            self.docs = json.load(json_file)
        # sbert bi-encoder that turns a query into a vector
        self.encoder = SentenceTransformer(DocumentSearch.enc_path)
        # faiss index over the document vectors
        self.index = faiss.read_index(DocumentSearch.idx_path)
        # The cross-encoder (DocumentSearch.cross_enc_path) previously used for
        # re-ranking is intentionally not loaded; see ``search``.

    def search(self, query: str, k: int) -> list:
        '''
        Return up to ``k`` documents most similar to ``query``.

        Args:
            query: free-text query.
            k: maximum number of results; ``k <= 0`` yields an empty list.

        Returns:
            list of dicts with keys ``'doc'`` (document text), ``'url'`` and
            ``'score'`` (the value faiss returned for that hit), in the order
            faiss returned them. Fewer than ``k`` entries are returned when
            faiss reports fewer real neighbours.
        '''
        if k <= 0:
            return []
        # encode() works on a batch, so wrap the single query in a list
        query_vector = self.encoder.encode([query])
        # Only k neighbours are needed now that re-ranking is gone. The old
        # cross-encoder version over-fetched k*10 candidates, scored
        # (query, doc) pairs with CrossEncoder.predict and sorted by that
        # score; restore both steps together if re-ranking comes back.
        distances, indices = self.index.search(query_vector, k)
        # row 0 holds the hits for our single query
        return self._collect_results(self.docs, distances[0], indices[0], k)

    @staticmethod
    def _collect_results(docs, distances, indices, k: int) -> list:
        '''
        Pair faiss hits with their ``[text, url]`` documents.

        Skips negative ids (faiss uses -1 for "no neighbour"; indexing
        ``docs[-1]`` would silently return the last document) and keeps at
        most ``k`` results, preserving the input order.
        '''
        results = []
        for idx, score in zip(indices, distances):
            if len(results) >= k:
                break
            if idx < 0:
                continue
            entry = docs[idx]
            results.append({'doc': entry[0], 'url': entry[1], 'score': score})
        return results
if __name__ == "__main__":
    # Streamlit re-runs this whole script on every interaction, so the models
    # and the index are loaded through a resource cache instead of on every
    # rerun. st.cache_resource needs Streamlit >= 1.18; older releases only
    # have the legacy st.cache, which must be told not to hash/copy the
    # returned (mutable) object.
    _cache_resource = getattr(st, "cache_resource", None)
    if _cache_resource is None:
        _cache_resource = st.cache(allow_output_mutation=True)

    @_cache_resource
    def _load_searcher():
        '''Build the expensive DocumentSearch instance; cached by Streamlit.'''
        return DocumentSearch()

    surfer = _load_searcher()

    # page title (kept on one line so Markdown can't mistake indentation
    # for a code block)
    title = (
        "<h1 style='text-align: center; color: #3CB371'>"
        "Medical Search"
        "</h1>"
    )
    st.markdown(title, unsafe_allow_html=True)

    # input form; the query is only processed once "Search" is pressed
    with st.form("my_form"):
        query = st.text_input("Enter query about our Medical Data",
                              placeholder="Type query here...",
                              max_chars=200)
        # Every form must have a submit button.
        submitted = st.form_submit_button("Search")

    if submitted and query.strip():
        # time only the retrieval itself; perf_counter is monotonic
        start = time.perf_counter()
        # retrieve the top 10 documents
        results = surfer.search(query, k=10)
        elapsed_time = round(time.perf_counter() - start, 2)
        # show which query was entered and how long the search took
        st.write(f"**Results Related to:** \"{query}\" ({elapsed_time} sec.)")
        for number, answer in enumerate(results, start=1):
            st.subheader(f"Answer {number}")
            # crop long answers; the link leads to the full text
            doc = answer["doc"]
            if len(doc) > 250:
                doc = doc[:250] + "..."
            url = answer["url"]
            st.markdown(f'{doc}\n[**Read More**]({url})\n', unsafe_allow_html=True)
        st.markdown("---")
        st.markdown("**Author:** Ivan Savchuk. 2022")
    else:
        # an empty submission would otherwise run a meaningless search
        if submitted:
            st.warning("Please enter a query.")
        st.markdown("Typical queries looks like this: _**\"What is flu?\"**_, "
                    "_**\"How to cure breast cancer?\"**_, "
                    "_**\"I have headache, what should I do?\"**_")