Spaces:
Runtime error
Runtime error
ivan-savchuk
committed on
Commit
•
0853141
1
Parent(s):
16fbbdb
update for faiss only
Browse files
app.py
CHANGED
@@ -32,24 +32,32 @@ class DocumentSearch:
|
|
32 |
# loading faiss index
|
33 |
self.index = faiss.read_index(DocumentSearch.idx_path)
|
34 |
# loading sbert cross_encoder
|
35 |
-
self.cross_encoder = CrossEncoder(DocumentSearch.cross_enc_path)
|
36 |
|
37 |
def search(self, query: str, k: int) -> list:
|
38 |
# get vector representation of text query
|
39 |
query_vector = self.encoder.encode([query])
|
40 |
# perform search via faiss FlatIP index
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
# get answers by index
|
43 |
-
answers = [self.docs[i] for i in indeces[0]]
|
44 |
# prepare inputs for cross encoder
|
45 |
-
model_inputs = [[query, pairs[0]] for pairs in answers]
|
46 |
-
urls = [pairs[1] for pairs in answers]
|
47 |
# get similarity score between query and documents
|
48 |
-
scores = self.cross_encoder.predict(model_inputs, batch_size=1)
|
49 |
# compose results into list of dicts
|
50 |
-
results = [{'doc': doc[1], 'url': url, 'score': score} for doc, url, score in zip(model_inputs, urls, scores)]
|
51 |
-
|
52 |
-
return
|
|
|
53 |
|
54 |
|
55 |
if __name__ == "__main__":
|
@@ -99,3 +107,7 @@ if __name__ == "__main__":
|
|
99 |
|
100 |
st.markdown("---")
|
101 |
st.markdown("**Author:** Ivan Savchuk. 2022")
|
|
|
|
|
|
|
|
|
|
32 |
# loading faiss index
|
33 |
self.index = faiss.read_index(DocumentSearch.idx_path)
|
34 |
# loading sbert cross_encoder
|
35 |
+
# self.cross_encoder = CrossEncoder(DocumentSearch.cross_enc_path)
|
36 |
|
37 |
def search(self, query: str, k: int) -> list:
    """Return the nearest documents to *query* using the FAISS index only.

    The query is embedded with ``self.encoder``, looked up in ``self.index``
    and each hit is mapped back to its entry in ``self.labels``.

    Args:
        query: Free-text search query.
        k: Requested number of results. The index is queried for ``k*10``
            neighbours and all of them are returned.
            NOTE(review): the cross-encoder version this replaces truncated
            to ``[:k]`` after re-ranking; confirm with the caller whether
            returning ``k*10`` hits is intended.

    Returns:
        A list of ``{'doc': ..., 'url': ..., 'score': ...}`` dicts in the
        order produced by the index. ``doc`` and ``url`` are items 0 and 1
        of the matching ``self.labels`` entry; ``score`` is the value the
        index reported for that hit.
    """
    # get vector representation of text query
    query_vector = self.encoder.encode([query])
    # perform search via faiss FlatIP index
    distances, indices = self.index.search(query_vector, k*10)
    # Pair every hit with its score. FAISS pads the id row with -1 when the
    # index holds fewer vectors than requested; drop those slots so that
    # ``self.labels[-1]`` is never returned as a bogus match.
    hits = [
        (self.labels[i], dist)
        for i, dist in zip(indices[0], distances[0])
        if i >= 0
    ]
    # compose results into list of dicts
    return [{'doc': doc[0], 'url': doc[1], 'score': dist} for doc, dist in hits]
|
48 |
+
##### OLD VERSION WITH CROSS-ENCODER #####
|
49 |
# get answers by index
|
50 |
+
#answers = [self.docs[i] for i in indeces[0]]
|
51 |
# prepare inputs for cross encoder
|
52 |
+
# model_inputs = [[query, pairs[0]] for pairs in answers]
|
53 |
+
# urls = [pairs[1] for pairs in answers]
|
54 |
# get similarity score between query and documents
|
55 |
+
# scores = self.cross_encoder.predict(model_inputs, batch_size=1)
|
56 |
# compose results into list of dicts
|
57 |
+
# results = [{'doc': doc[1], 'url': url, 'score': score} for doc, url, score in zip(model_inputs, urls, scores)]
|
58 |
+
|
59 |
+
# return results sorted by similarity scores
|
60 |
+
# return sorted(results, key=lambda x: x['score'], reverse=True)[:k]
|
61 |
|
62 |
|
63 |
if __name__ == "__main__":
|
|
|
107 |
|
108 |
st.markdown("---")
|
109 |
st.markdown("**Author:** Ivan Savchuk. 2022")
|
110 |
+
else:
|
111 |
+
st.markdown("Typical queries looks like this: _**\"What is flu?\"**_,\
|
112 |
+
_**\"How to cure breast cancer?\"**_,\
|
113 |
+
_**\"I have headache, what should I do?\"**_")
|