Omar Solano committed on
Commit
cbb80f5
·
1 Parent(s): 65b9328

add cohere reranking

Browse files
scripts/custom_retriever.py CHANGED
@@ -6,6 +6,7 @@ import logfire
6
  from llama_index.core import QueryBundle
7
  from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
8
  from llama_index.core.schema import NodeWithScore, TextNode
 
9
 
10
  logger = logging.getLogger(__name__)
11
  logging.basicConfig(level=logging.INFO)
@@ -71,4 +72,8 @@ class CustomRetriever(BaseRetriever):
71
  else:
72
  nodes_context.append(node)
73
 
 
 
 
 
74
  return nodes_context
 
6
  from llama_index.core import QueryBundle
7
  from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
8
  from llama_index.core.schema import NodeWithScore, TextNode
9
+ from llama_index.postprocessor.cohere_rerank import CohereRerank
10
 
11
  logger = logging.getLogger(__name__)
12
  logging.basicConfig(level=logging.INFO)
 
72
  else:
73
  nodes_context.append(node)
74
 
75
+ reranker = CohereRerank(top_n=5, model="rerank-english-v3.0")
76
+ nodes_context = reranker.postprocess_nodes(nodes_context, query_bundle)
77
+ logfire.info(f"Cohere raranking to {len(nodes_context)} nodes")
78
+
79
  return nodes_context
scripts/gradio-ui.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import pickle
3
 
@@ -22,6 +23,9 @@ from tutor_prompts import system_message_openai_agent
22
 
23
  load_dotenv()
24
 
 
 
 
25
  logfire.configure()
26
 
27
 
@@ -67,11 +71,10 @@ AVAILABLE_SOURCES = [
67
  ]
68
 
69
 
70
- # # Initialize MongoDB
71
  # mongo_db = (
72
  # init_mongo_db(uri=MONGODB_URI, db_name="towardsai-buster")
73
  # if MONGODB_URI
74
- # else logger.warning("No mongodb uri found, you will not be able to save data.")
75
  # )
76
 
77
 
@@ -223,11 +226,40 @@ def generate_completion(
223
 
224
 
225
  def vote(data: gr.LikeData):
 
226
  if data.liked:
227
  print("You upvoted this response: " + data.value["value"])
228
  else:
229
  print("You downvoted this response: " + data.value["value"])
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  accordion = gr.Accordion(label="Customize Sources (Click to expand)", open=False)
233
  sources = gr.CheckboxGroup(
 
1
+ import logging
2
  import os
3
  import pickle
4
 
 
23
 
24
  load_dotenv()
25
 
26
+ logger = logging.getLogger(__name__)
27
+ logging.basicConfig(level=logging.INFO)
28
+ logging.getLogger("httpx").setLevel(logging.WARNING)
29
  logfire.configure()
30
 
31
 
 
71
  ]
72
 
73
 
 
74
  # mongo_db = (
75
  # init_mongo_db(uri=MONGODB_URI, db_name="towardsai-buster")
76
  # if MONGODB_URI
77
+ # else logfire.warn("No mongodb uri found, you will not be able to save data.")
78
  # )
79
 
80
 
 
226
 
227
 
228
  def vote(data: gr.LikeData):
229
+ collection = "liked_data-test"
230
  if data.liked:
231
  print("You upvoted this response: " + data.value["value"])
232
  else:
233
  print("You downvoted this response: " + data.value["value"])
234
 
235
+ # completion_json["liked"] = like_data.liked
236
+ # logger.info(f"User reported {like_data.liked=}")
237
+
238
+ # try:
239
+ # cfg.mongo_db[collection].insert_one(completion_json)
240
+ # except:
241
+ # logger.info("Something went wrong logging")
242
+
243
+
244
+ # def save_completion(completion: Completion, history):
245
+ # collection = "completion_data-hf"
246
+
247
+ # # Convert completion to JSON and ignore certain columns
248
+ # completion_json = completion.to_json(
249
+ # columns_to_ignore=["embedding", "similarity", "similarity_to_answer"]
250
+ # )
251
+
252
+ # # Add the current date and time to the JSON
253
+ # completion_json["timestamp"] = datetime.utcnow().isoformat()
254
+ # completion_json["history"] = history
255
+ # completion_json["history_len"] = len(history)
256
+
257
+ # try:
258
+ # cfg.mongo_db[collection].insert_one(completion_json)
259
+ # logger.info("Completion saved to db")
260
+ # except Exception as e:
261
+ # logger.info(f"Something went wrong logging completion to db: {e}")
262
+
263
 
264
  accordion = gr.Accordion(label="Customize Sources (Click to expand)", open=False)
265
  sources = gr.CheckboxGroup(