Jenechek committed on
Commit
c9775bb
1 Parent(s): 1ac805e

add cross-encoder

Browse files
Files changed (2) hide show
  1. app.py +5 -4
  2. backend/semantic_search.py +12 -3
app.py CHANGED
@@ -34,7 +34,7 @@ def add_text(history, text):
34
  return history, gr.Textbox(value="", interactive=False)
35
 
36
 
37
- def bot(history, api_kind):
38
  query = history[-1][0]
39
 
40
  if not query:
@@ -44,7 +44,7 @@ def bot(history, api_kind):
44
  # Retrieve documents relevant to query
45
  document_start = perf_counter()
46
 
47
- documents = retrieve(query, TOP_K)
48
 
49
  document_time = perf_counter() - document_start
50
  logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
@@ -86,12 +86,13 @@ with gr.Blocks() as demo:
86
  )
87
  txt_btn = gr.Button(value="Submit text", scale=1)
88
 
89
- api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
 
90
 
91
  prompt_html = gr.HTML()
92
  # Turn off interactivity while generating if you click
93
  txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
94
- bot, [chatbot, api_kind], [chatbot, prompt_html])
95
 
96
  # Turn it back on
97
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
 
34
  return history, gr.Textbox(value="", interactive=False)
35
 
36
 
37
+ def bot(history, api_kind, with_cross_encoder):
38
  query = history[-1][0]
39
 
40
  if not query:
 
44
  # Retrieve documents relevant to query
45
  document_start = perf_counter()
46
 
47
+ documents = retrieve(query, TOP_K, with_cross_encoder)
48
 
49
  document_time = perf_counter() - document_start
50
  logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
 
86
  )
87
  txt_btn = gr.Button(value="Submit text", scale=1)
88
 
89
+ api_kind = gr.Checkbox(label="Cross-encoder")
90
+ cross_encoder = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
91
 
92
  prompt_html = gr.HTML()
93
  # Turn off interactivity while generating if you click
94
  txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
95
+ bot, [chatbot, api_kind, cross_encoder], [chatbot, prompt_html])
96
 
97
  # Turn it back on
98
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
backend/semantic_search.py CHANGED
@@ -2,6 +2,7 @@ import lancedb
2
  import os
3
  import gradio as gr
4
  from sentence_transformers import SentenceTransformer
 
5
 
6
 
7
  db = lancedb.connect(".lancedb")
@@ -12,13 +13,21 @@ TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
12
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
13
 
14
  retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
 
15
 
16
 
17
- def retrieve(query, k):
18
  query_vec = retriever.encode(query)
19
  try:
20
- documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
21
- documents = [doc[TEXT_COLUMN] for doc in documents]
 
 
 
 
 
 
 
22
 
23
  return documents
24
 
 
2
  import os
3
  import gradio as gr
4
  from sentence_transformers import SentenceTransformer
5
+ from sentence_transformers import CrossEncoder
6
 
7
 
8
  db = lancedb.connect(".lancedb")
 
13
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
14
 
15
  retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
16
+ cross_encoder = CrossEncoder(os.getenv("RERANK_MODEL"), max_length=512)
17
 
18
 
19
+ def retrieve(query, k, with_cross_encoder=False):
20
  query_vec = retriever.encode(query)
21
  try:
22
+ if not with_cross_encoder:
23
+ documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
24
+ documents = [doc[TEXT_COLUMN] for doc in documents]
25
+ else:
26
+ documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k * 2).to_list()
27
+ scores = cross_encoder.predict([(query, doc[TEXT_COLUMN]) for doc in documents])
28
+ indexed_arr = [(elem, index) for index, elem in enumerate(scores)]
29
+ sorted_arr = sorted(indexed_arr, key=lambda x: x[0], reverse=True)
30
+ documents = [elem for elem, _ in sorted_arr[:k]]
31
 
32
  return documents
33