fizban99 committed on
Commit
eaee63c
1 Parent(s): 2da89ac

reranking added

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +10 -2
  3. simiandb.py +2 -2
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ *.pyc
app.py CHANGED
@@ -7,17 +7,25 @@ Created on Wed Mar 22 19:59:54 2023
7
  import gradio as gr
8
  from simiandb import Simiandb
9
  from langchain.embeddings import HuggingFaceEmbeddings
 
10
 
11
 
12
 
13
 
14
  model_name = "all-MiniLM-L6-v2"
15
  hf = HuggingFaceEmbeddings(model_name=model_name)
 
16
 
17
  documentdb = Simiandb("mystore", embedding_function=hf, mode="a")
18
 
19
  def search(query):
20
- return documentdb.similarity_search(query)
 
 
 
 
21
 
22
  iface = gr.Interface(fn=search, inputs="text", outputs="text")
23
- iface.launch()
 
 
 
7
  import gradio as gr
8
  from simiandb import Simiandb
9
  from langchain.embeddings import HuggingFaceEmbeddings
10
+ from sentence_transformers import CrossEncoder
11
 
12
 
13
 
14
 
15
  model_name = "all-MiniLM-L6-v2"
16
  hf = HuggingFaceEmbeddings(model_name=model_name)
17
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
18
 
19
  documentdb = Simiandb("mystore", embedding_function=hf, mode="a")
20
 
21
  def search(query):
22
+ hits = documentdb.similarity_search(query)
23
+ cross_inp = [[query, hit] for hit in hits]
24
+ cross_scores = cross_encoder.predict(cross_inp)
25
+ hits = [hit for _, hit in sorted(zip(cross_scores, hits), reverse=True)]
26
+ return hits[0]
27
 
28
  iface = gr.Interface(fn=search, inputs="text", outputs="text")
29
+ iface.launch()
30
+
31
+ #print(search("what is the balloon boy hoax"))
simiandb.py CHANGED
@@ -178,7 +178,7 @@ class Simiandb():
178
  batch = self._vector_table.chunkshape[0]*25
179
  res = np.ascontiguousarray(np.empty(shape=(count,), dtype="float32"))
180
  end = 0
181
- a = time()
182
  while end!=count:
183
  end += batch
184
  end = end if end <= count else count
@@ -189,7 +189,7 @@ class Simiandb():
189
 
190
  indices = np.argpartition(res, -k)[-k:] #from https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
191
  indices = indices[np.argsort(res[indices])[::-1]]
192
- print(time() -a)
193
  return indices
194
 
195
 
 
178
  batch = self._vector_table.chunkshape[0]*25
179
  res = np.ascontiguousarray(np.empty(shape=(count,), dtype="float32"))
180
  end = 0
181
+
182
  while end!=count:
183
  end += batch
184
  end = end if end <= count else count
 
189
 
190
  indices = np.argpartition(res, -k)[-k:] #from https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
191
  indices = indices[np.argsort(res[indices])[::-1]]
192
+
193
  return indices
194
 
195