z00mP committed on
Commit
0102b8a
1 Parent(s): 8663794

add reranker

Browse files
Files changed (2) hide show
  1. app.py +21 -5
  2. backend/reranker.py +13 -0
app.py CHANGED
@@ -11,6 +11,7 @@ from jinja2 import Environment, FileSystemLoader
11
 
12
  from backend.query_llm import generate_hf, generate_openai
13
  from backend.semantic_search import retrieve
 
14
 
15
 
16
  TOP_K = int(os.getenv("TOP_K", 4))
@@ -34,7 +35,7 @@ def add_text(history, text):
34
  return history, gr.Textbox(value="", interactive=False)
35
 
36
 
37
- def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param):
38
  top_k_param = int(top_k_param)
39
  query = history[-1][0]
40
 
@@ -47,6 +48,11 @@ def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encode
47
 
48
  #documents = retrieve(query, TOP_K)
49
  documents = retrieve(query, top_k_param, chunk_table, embedding_model)
 
 
 
 
 
50
 
51
 
52
  document_time = perf_counter() - document_start
@@ -121,7 +127,7 @@ with gr.Blocks() as demo:
121
  )
122
  cross_encoder = gr.Radio(
123
  choices=[
124
- "None"
125
  "BAAI/bge-reranker-large",
126
  "cross-encoder/ms-marco-MiniLM-L-6-v2",
127
  ],
@@ -137,20 +143,30 @@ with gr.Blocks() as demo:
137
  ],
138
  value="5",
139
  label='top-K'
140
- )
 
 
 
 
 
 
 
 
 
 
141
 
142
 
143
  prompt_html = gr.HTML()
144
  # Turn off interactivity while generating if you click
145
  txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
146
- bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param], [chatbot, prompt_html])
147
 
148
  # Turn it back on
149
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
150
 
151
  # Turn off interactivity while generating if you hit enter
152
  txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
153
- bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param], [chatbot, prompt_html])
154
 
155
  # Turn it back on
156
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
 
11
 
12
  from backend.query_llm import generate_hf, generate_openai
13
  from backend.semantic_search import retrieve
14
+ from backend.reranker import rerank_documents
15
 
16
 
17
  TOP_K = int(os.getenv("TOP_K", 4))
 
35
  return history, gr.Textbox(value="", interactive=False)
36
 
37
 
38
+ def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, rerank_topk ):
39
  top_k_param = int(top_k_param)
40
  query = history[-1][0]
41
 
 
48
 
49
  #documents = retrieve(query, TOP_K)
50
  documents = retrieve(query, top_k_param, chunk_table, embedding_model)
51
+ if cross_encoder != "None" and len(documents) > 1:
52
+ documents = rerank_documents(query, documents, query, top_k_rerank=rerank_topk)
53
+ #"cross-encoder/ms-marco-MiniLM-L-6-v2"
54
+
55
+
56
 
57
 
58
  document_time = perf_counter() - document_start
 
127
  )
128
  cross_encoder = gr.Radio(
129
  choices=[
130
+ "None",
131
  "BAAI/bge-reranker-large",
132
  "cross-encoder/ms-marco-MiniLM-L-6-v2",
133
  ],
 
143
  ],
144
  value="5",
145
  label='top-K'
146
+ )
147
+ rerank_topk = gr.Radio(
148
+ choices=[
149
+ "5",
150
+ "10",
151
+ "20",
152
+ "50",
153
+ ],
154
+ value="5",
155
+ label='rerank-top-K'
156
+ )
157
 
158
 
159
  prompt_html = gr.HTML()
160
  # Turn off interactivity while generating if you click
161
  txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
162
+ bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param, rerank_topk], [chatbot, prompt_html])
163
 
164
  # Turn it back on
165
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
166
 
167
  # Turn off interactivity while generating if you hit enter
168
  txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
169
+ bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param, rerank_topk], [chatbot, prompt_html])
170
 
171
  # Turn it back on
172
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
backend/reranker.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import CrossEncoder
2
+
3
+
4
# Cache of loaded CrossEncoder models, keyed by model name, so repeated
# calls do not reload the (large) model weights on every query.
_CE_MODEL_CACHE = {}


def rerank_documents(ce_model_name, documents, query, top_k_rerank):
    """Rerank *documents* against *query* with a cross-encoder and keep the best.

    Parameters
    ----------
    ce_model_name : str
        sentence-transformers CrossEncoder model name
        (e.g. "BAAI/bge-reranker-large").
    documents : list[str]
        Candidate documents to rerank.
    query : str
        The user query each document is scored against.
    top_k_rerank : int | str
        Number of top-scoring documents to return (the Gradio radio
        delivers this as a string, hence the int() coercion).

    Returns
    -------
    list[str]
        At most ``top_k_rerank`` documents, sorted by descending
        cross-encoder relevance score.
    """
    top_k_rerank = int(top_k_rerank)
    # Nothing to score: skip model loading and predict([]) entirely.
    if not documents or top_k_rerank <= 0:
        return []

    # Reuse a previously loaded model instead of re-instantiating it
    # (and re-reading its weights) on every request.
    model = _CE_MODEL_CACHE.get(ce_model_name)
    if model is None:
        model = CrossEncoder(ce_model_name, max_length=512)
        _CE_MODEL_CACHE[ce_model_name] = model

    pairs = [(query, doc) for doc in documents]
    scores = model.predict(pairs)

    # Key on the score only, so (possibly numpy) scores are never compared
    # against documents on ties; highest score first.
    ranked = sorted(zip(scores, documents), key=lambda sd: sd[0], reverse=True)
    return [doc for _, doc in ranked[:top_k_rerank]]