Spaces:
Runtime error
Runtime error
add cross-encoder
Browse files
- app.py +5 -4
- backend/semantic_search.py +12 -3
app.py
CHANGED
@@ -34,7 +34,7 @@ def add_text(history, text):
|
|
34 |
return history, gr.Textbox(value="", interactive=False)
|
35 |
|
36 |
|
37 |
-
def bot(history, api_kind):
|
38 |
query = history[-1][0]
|
39 |
|
40 |
if not query:
|
@@ -44,7 +44,7 @@ def bot(history, api_kind):
|
|
44 |
# Retrieve documents relevant to query
|
45 |
document_start = perf_counter()
|
46 |
|
47 |
-
documents = retrieve(query, TOP_K)
|
48 |
|
49 |
document_time = perf_counter() - document_start
|
50 |
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
@@ -86,12 +86,13 @@ with gr.Blocks() as demo:
|
|
86 |
)
|
87 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
88 |
|
89 |
-
api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
|
|
|
90 |
|
91 |
prompt_html = gr.HTML()
|
92 |
# Turn off interactivity while generating if you click
|
93 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
94 |
-
bot, [chatbot, api_kind], [chatbot, prompt_html])
|
95 |
|
96 |
# Turn it back on
|
97 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
|
|
34 |
return history, gr.Textbox(value="", interactive=False)
|
35 |
|
36 |
|
37 |
+
def bot(history, api_kind, with_cross_encoder):
|
38 |
query = history[-1][0]
|
39 |
|
40 |
if not query:
|
|
|
44 |
# Retrieve documents relevant to query
|
45 |
document_start = perf_counter()
|
46 |
|
47 |
+
documents = retrieve(query, TOP_K, with_cross_encoder)
|
48 |
|
49 |
document_time = perf_counter() - document_start
|
50 |
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
|
|
86 |
)
|
87 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
88 |
|
89 |
+
api_kind = gr.Checkbox(label="Cross-encoder")
|
90 |
+
cross_encoder = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
|
91 |
|
92 |
prompt_html = gr.HTML()
|
93 |
# Turn off interactivity while generating if you click
|
94 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
95 |
+
bot, [chatbot, api_kind, cross_encoder], [chatbot, prompt_html])
|
96 |
|
97 |
# Turn it back on
|
98 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
backend/semantic_search.py
CHANGED
@@ -2,6 +2,7 @@ import lancedb
|
|
2 |
import os
|
3 |
import gradio as gr
|
4 |
from sentence_transformers import SentenceTransformer
|
|
|
5 |
|
6 |
|
7 |
db = lancedb.connect(".lancedb")
|
@@ -12,13 +13,21 @@ TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
|
12 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
13 |
|
14 |
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
|
|
15 |
|
16 |
|
17 |
-
def retrieve(query, k):
|
18 |
query_vec = retriever.encode(query)
|
19 |
try:
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
return documents
|
24 |
|
|
|
2 |
import os
|
3 |
import gradio as gr
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
+
from sentence_transformers import CrossEncoder
|
6 |
|
7 |
|
8 |
db = lancedb.connect(".lancedb")
|
|
|
13 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
14 |
|
15 |
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
16 |
+
cross_encoder = CrossEncoder(os.getenv("RERANK_MODEL"), max_length=512)
|
17 |
|
18 |
|
19 |
+
def retrieve(query, k, with_cross_encoder=False):
|
20 |
query_vec = retriever.encode(query)
|
21 |
try:
|
22 |
+
if not with_cross_encoder:
|
23 |
+
documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
|
24 |
+
documents = [doc[TEXT_COLUMN] for doc in documents]
|
25 |
+
else:
|
26 |
+
documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k * 2).to_list()
|
27 |
+
scores = cross_encoder.predict([(query, doc[TEXT_COLUMN]) for doc in documents])
|
28 |
+
indexed_arr = [(elem, index) for index, elem in enumerate(scores)]
|
29 |
+
sorted_arr = sorted(indexed_arr, key=lambda x: x[0], reverse=True)
|
30 |
+
documents = [elem for elem, _ in sorted_arr[:k]]
|
31 |
|
32 |
return documents
|
33 |
|