add reranker

Files changed:
- app.py (+21 -5)
- backend/reranker.py (+13 -0)
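The change adds an optional cross-encoder reranking stage to the retrieval path: bot() still fetches top_k_param candidate chunks from the vector store, but when the cross_encoder radio is set to anything other than "None", the new backend/reranker.py rescores each (query, chunk) pair and keeps only the best rerank_topk chunks. A new rerank-top-K radio exposes that cutoff, and both the click and submit handlers pass it through to bot().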
app.py
CHANGED
@@ -11,6 +11,7 @@ from jinja2 import Environment, FileSystemLoader
 
 from backend.query_llm import generate_hf, generate_openai
 from backend.semantic_search import retrieve
+from backend.reranker import rerank_documents
 
 
 TOP_K = int(os.getenv("TOP_K", 4))
@@ -34,7 +35,7 @@ def add_text(history, text):
     return history, gr.Textbox(value="", interactive=False)
 
 
-def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param):
+def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param, rerank_topk):
     top_k_param = int(top_k_param)
     query = history[-1][0]
 
@@ -47,6 +48,11 @@ def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encode
 
     #documents = retrieve(query, TOP_K)
     documents = retrieve(query, top_k_param, chunk_table, embedding_model)
+    # Optionally rescore the retrieved chunks with the selected cross-encoder
+    # (e.g. "cross-encoder/ms-marco-MiniLM-L-6-v2"); the radio returns the
+    # string "None" when reranking is disabled.
+    if cross_encoder != "None" and len(documents) > 1:
+        documents = rerank_documents(cross_encoder, documents, query, top_k_rerank=rerank_topk)
 
 
     document_time = perf_counter() - document_start
@@ -121,7 +127,7 @@ with gr.Blocks() as demo:
     )
     cross_encoder = gr.Radio(
         choices=[
-            "None"
+            "None",
             "BAAI/bge-reranker-large",
             "cross-encoder/ms-marco-MiniLM-L-6-v2",
         ],
@@ -137,20 +143,30 @@ with gr.Blocks() as demo:
         ],
         value="5",
         label='top-K'
-        )
+    )
+    rerank_topk = gr.Radio(
+        choices=[
+            "5",
+            "10",
+            "20",
+            "50",
+        ],
+        value="5",
+        label='rerank-top-K'
+    )
 
 
     prompt_html = gr.HTML()
     # Turn off interactivity while generating if you click
     txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param], [chatbot, prompt_html])
+        bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param, rerank_topk], [chatbot, prompt_html])
 
     # Turn it back on
     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
 
     # Turn off interactivity while generating if you hit enter
     txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param], [chatbot, prompt_html])
+        bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param, rerank_topk], [chatbot, prompt_html])
 
     # Turn it back on
     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
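For orientation, here is a minimal sketch of the two-stage funnel this wiring sets up. The retrieve and rerank_documents imports are the Space's own; the query, chunk-table, and embedding-model strings are hypothetical placeholders standing in for whatever the Gradio radios would supply (gr.Radio delivers its selection as a string, which is why bot() casts top_k_param with int).

    from backend.semantic_search import retrieve
    from backend.reranker import rerank_documents

    query = "How do I enable reranking?"       # hypothetical user question
    chunk_table = "my_chunk_table"             # placeholder: real value comes from the UI radio
    embedding_model = "my_embedding_model"     # placeholder: real value comes from the UI radio

    # Stage 1: the vector store over-fetches candidates (top_k_param of them).
    candidates = retrieve(query, 10, chunk_table, embedding_model)

    # Stage 2: the cross-encoder rescores each (query, chunk) pair and keeps
    # only the rerank_topk best, mirroring the guard in bot().
    if len(candidates) > 1:
        candidates = rerank_documents("cross-encoder/ms-marco-MiniLM-L-6-v2",
                                      candidates, query, top_k_rerank=5)

Over-fetching in stage 1 and cutting in stage 2 only helps when top_k_param exceeds rerank_topk; if the rerank cutoff is larger than the candidate pool, the slice in rerank_documents simply returns everything, reordered.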
backend/reranker.py
ADDED
@@ -0,0 +1,13 @@
+from sentence_transformers import CrossEncoder
+
+
+def rerank_documents(ce_model_name, documents, query, top_k_rerank):
+    # Score every (query, document) pair with the cross-encoder,
+    # then keep the top_k_rerank highest-scoring documents.
+    top_k_rerank = int(top_k_rerank)
+    pairs = [(query, doc) for doc in documents]
+    ce_model = CrossEncoder(ce_model_name, max_length=512)
+    scores = ce_model.predict(pairs)
+    ranked = sorted(zip(scores, pairs), key=lambda p: p[0], reverse=True)
+    reranked_docs = [pair[1] for _, pair in ranked]
+    return reranked_docs[:top_k_rerank]
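A quick way to sanity-check the helper in isolation (the model name is one of the two the UI offers; the documents and query are made up):

    from backend.reranker import rerank_documents

    docs = [
        "Cross-encoders score a query and a passage jointly.",
        "Gradio radios return the selected choice as a string.",
        "The Eiffel Tower is in Paris.",
    ]
    top = rerank_documents("cross-encoder/ms-marco-MiniLM-L-6-v2",
                           docs, "What does a cross-encoder do?", top_k_rerank=2)
    print(top)  # the two passages most relevant to the query, best first

One design consequence worth noting: rerank_documents constructs a CrossEncoder on every call, so each chat turn reloads the model from disk (and a cold Space re-downloads it). A minimal sketch of one way to amortize that, assuming the small fixed set of model names the radio offers, puts a cached loader in front of the constructor:

    from functools import lru_cache

    from sentence_transformers import CrossEncoder

    @lru_cache(maxsize=2)  # one slot per reranker the UI offers
    def get_cross_encoder(model_name: str) -> CrossEncoder:
        # Loaded once per model name, then reused across requests.
        return CrossEncoder(model_name, max_length=512)

    def rerank_documents(ce_model_name, documents, query, top_k_rerank):
        top_k_rerank = int(top_k_rerank)
        pairs = [(query, doc) for doc in documents]
        scores = get_cross_encoder(ce_model_name).predict(pairs)
        # argsort ascending, reversed: document indices from best to worst
        order = scores.argsort()[::-1][:top_k_rerank]
        return [documents[i] for i in order]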