melk2025 committed
Commit e8da5e2 · verified · 1 Parent(s): 23e5a90

added history

Files changed (1): app.py (+30 -122)
app.py CHANGED
@@ -1,4 +1,4 @@
-import chromadb
+import chromadb
 import pandas as pd
 from sentence_transformers import SentenceTransformer
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -8,7 +8,7 @@ from openai import OpenAI
 import numpy as np
 import requests
 import chromadb
-from chromadb import Client
+from chromadb import Client
 from sentence_transformers import SentenceTransformer, util
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from chromadb import Client
@@ -21,11 +21,8 @@ import requests
 import time
 import tempfile
 
-#HF_TOKEN = os.getenv("HF_TOKEN")
-
 API_KEY = os.environ.get("OPENROUTER_API_KEY")
 
-
 # Load the Excel file
 df = pd.read_excel("web_documents.xlsx", engine='openpyxl')
 
@@ -38,14 +35,8 @@ collection = client.get_or_create_collection(
     metadata={"hnsw:space": "cosine"}
 )
 
-# Load the embedding model new model
-#embedding_model = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-dot-v1')
-#embedding_model = SentenceTransformer("BAAI/bge-m3")
-embedding_model = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
-
-
-
-
+# Load the embedding model
+embedding_model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')
 
 # Initialize the text splitter
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=150)
@@ -78,21 +69,12 @@ for idx, row in df.iterrows():
 
 # ---------------------- Config ----------------------
 SIMILARITY_THRESHOLD = 0.80
-
-client1 = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=API_KEY)  # remplace par ta clé OpenRouter
-
+client1 = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=API_KEY)  # Replace with your OpenRouter API key
 
 # ---------------------- Models ----------------------
-# High-accuracy model for semantic search
-#semantic_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
-#semantic_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
-semantic_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
-
+semantic_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
 
-# For ChromaDB
-#embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
-
-# ---------------------- Load QA Data ----------------------
+# Load QA Data
 with open("qa.json", "r", encoding="utf-8") as f:
     qa_data = json.load(f)
 
@@ -100,7 +82,7 @@ qa_questions = list(qa_data.keys())
 qa_answers = list(qa_data.values())
 qa_embeddings = semantic_model.encode(qa_questions, convert_to_tensor=True)
 
-# ---------------------- CAG ----------------------
+# ---------------------- History-Aware CAG ----------------------
 def retrieve_from_cag(user_query):
     query_embedding = semantic_model.encode(user_query, convert_to_tensor=True)
     cosine_scores = util.cos_sim(query_embedding, qa_embeddings)[0]
@@ -109,21 +91,19 @@ def retrieve_from_cag(user_query):
 
     print(f"[CAG] Best score: {best_score:.4f} | Closest question: {qa_questions[best_idx]}")
     if best_score >= SIMILARITY_THRESHOLD:
-        return qa_answers[best_idx], best_score
+        return qa_answers[best_idx], best_score  # Cache hit: return the stored answer and its score
     else:
         return None, best_score
 
-# ---------------------- RAG ----------------------
-#client = chromadb.Client()
-#collection = client.get_collection(name="rag_web_db_cosine_full_documents")
-# Assuming you have a persistent Chroma client setup
-#client = PersistentClient("./db_new/db_new")  # Replace with the correct path if needed
-#collection = client.get_collection(name="rag_web_db_cosine_full_documents")
-# ---------------------- RAG retrieval ----------------------
-def retrieve_from_rag(user_query):
-    print("Searching in RAG...")
-
-    query_embedding = embedding_model.encode(user_query)
+# ---------------------- History-Aware RAG ----------------------
+def retrieve_from_rag(user_query, chat_history):
+    # Fold previous turns into the retrieval query (skip the seed message, whose user slot is None)
+    history_context = " ".join(f"User: {q} Bot: {a}" for q, a in chat_history if q) + " "
+    full_query = history_context + user_query
+
+    print("Searching in RAG with history context...")
+
+    query_embedding = embedding_model.encode(full_query)
     results = collection.query(query_embeddings=[query_embedding], n_results=3)
 
     if not results or not results.get('documents'):
@@ -151,11 +131,8 @@ Instructions:
 - Use only the provided documents below to answer.
 - If the answer is not in the documents, simply say: "I don't know." / "Je ne sais pas."
 - Cite only the sources you use, indicated at the end of each document like (Source: https://example.com).
-
-
 Documents :
 {context}
-
 Question : {query}
 Answer :
 [/INST]
@@ -171,60 +148,20 @@ Answer :
         print(f"Erreur lors de la génération : {e}")
         return "Erreur lors de la génération."
 
-# ---------------------- Generation function (Huggingface) ----------------------
-def generate_via_huggingface(context, query, max_new_tokens=512, hf_token="your_huggingface_token"):
-    print("\n--- Generating via Huggingface ---")
-    print("Context received:", context)
-
-    prompt = f"""<s>[INST]
-You are a Moodle expert assistant.
-
-Rules:
-- Answer only based on the provided documents.
-- If the answer is not found, reply: "I don't know."
-- Only cite sources mentioned (metadata 'source').
-
-Documents:
-{context}
-
-Question: {query}
-Answer:
-[/INST]
-"""
-
-    API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
-    headers = {"Authorization": f"Bearer {hf_token}"}
-    payload = {
-        "inputs": prompt,
-        "parameters": {
-            "max_new_tokens": max_new_tokens
-        }
-    }
-
-    response = requests.post(API_URL, headers=headers, json=payload)
-
-    if response.status_code == 200:
-        result = response.json()
-        if isinstance(result, list) and "generated_text" in result[0]:
-            return result[0]["generated_text"].strip()
-        else:
-            return "Error: Unexpected response format."
-    else:
-        return f"Error {response.status_code}: {response.text}"
-
 # ---------------------- Main Chatbot ----------------------
-def chatbot(query):
+def chatbot(query, chat_history):
     print("\n==== New Query ====")
     print("User Query:", query)
 
     # Try to retrieve from CAG (cache)
-    answer, score = retrieve_from_cag(query)
+    answer, score = retrieve_from_cag(query)  # CAG matches on the current question only
     if answer:
         print("Answer retrieved from CAG cache.")
+        # ask() appends this turn to chat_history; appending here as well would duplicate it
        return answer
 
     # If not found, retrieve from RAG
-    docs = retrieve_from_rag(query)
+    docs = retrieve_from_rag(query, chat_history)
    if docs:
        context_blocks = []
        for doc in docs:
@@ -241,47 +178,17 @@ def chatbot(query):
 
         context = "\n\n".join(context_blocks)
 
-        # Choose the generation backend (OpenRouter or Huggingface)
+        # Generate via the OpenRouter backend
         response = generate_via_openrouter(context, query)
+        # ask() records the (query, response) pair in chat_history after we return
         return response
 
     else:
         print("No relevant documents found.")
+        # ask() records this fallback reply in chat_history as well
         return "Je ne sais pas."
 
-
 # ---------------------- Gradio App ----------------------
-
-# Define the chatbot response function
-#def ask(user_message, chat_history):
-#    if not user_message:
-#        return chat_history, chat_history, ""
-#
-#    # Get chatbot response
-#    response = chatbot(user_message)
-
-#    # Update chat history
-#    chat_history.append((user_message, response))
-#    return chat_history, chat_history, ""
-
-# Initialize chat history with a welcome message
-#initial_message = (None, "Hello, how can I help you with Moodle?")
-
-# Build Gradio interface
-#with gr.Blocks(theme=gr.themes.Soft()) as demo:
-#    chat_history = gr.State([initial_message])  # <-- Move inside here!
-
-#    chatbot_ui = gr.Chatbot(value=[initial_message])
-#    question = gr.Textbox(placeholder="Ask me anything about Moodle...", show_label=False)
-#    clear_button = gr.Button("Clear")
-
-#    question.submit(ask, [question, chat_history], [chatbot_ui, chat_history, question])
-#    clear_button.click(lambda: ([initial_message], [initial_message], ""), None, [chatbot_ui, chat_history, question], queue=False)
-
-#demo.queue()
-#demo.launch(share=False)
-# Initialize chat history with a welcome message
-
 def save_chat_to_file(chat_history):
     timestamp = time.strftime("%Y%m%d-%H%M%S")
     filename = f"chat_history_{timestamp}.json"
@@ -294,17 +201,18 @@ def save_chat_to_file(chat_history):
     with open(file_path, "w", encoding="utf-8") as f:
         json.dump(chat_history, f, ensure_ascii=False, indent=2)
 
-    return file_path  # THIS should be only the path, not a tuple!
+    return file_path
 
 def ask(user_message, chat_history):
     if not user_message:
         return chat_history, chat_history, ""
 
-    response = chatbot(user_message)
+    response = chatbot(user_message, chat_history)
     chat_history.append((user_message, response))
 
     return chat_history, chat_history, ""
 
+# Initialize chat history with a welcome message
 initial_message = (None, "Hello, how can I help you with Moodle?")
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -313,11 +221,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     chatbot_ui = gr.Chatbot(value=[initial_message])
     question = gr.Textbox(placeholder="Ask me anything about Moodle...", show_label=False)
     clear_button = gr.Button("Clear")
-    save_button = gr.Button("Save Chat")
+    save_button = gr.Button("Save Chat")
 
     question.submit(ask, [question, chat_history], [chatbot_ui, chat_history, question])
     clear_button.click(lambda: ([initial_message], [initial_message], ""), None, [chatbot_ui, chat_history, question], queue=False)
-
+
     save_button.click(save_chat_to_file, [chat_history], gr.File(label="Download your chat history"))
 
 demo.queue()
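The substance of this commit is the history-aware retrieval path: retrieve_from_rag() now embeds the current question together with a flattened transcript of earlier turns, so a follow-up like "how do I add one?" can still retrieve documents about the topic introduced earlier. A minimal sketch of how the composed query grows across turns (the sample turns are hypothetical; the join mirrors the code above):

# Sketch: how chat history folds into the retrieval query across turns.
def compose_query(user_query, chat_history):
    history_context = " ".join(f"User: {q} Bot: {a}" for q, a in chat_history if q) + " "
    return history_context + user_query

history = [(None, "Hello, how can I help you with Moodle?")]  # seed message; its None user slot is skipped
print(compose_query("What is a Moodle quiz?", history))
# -> " What is a Moodle quiz?"

history.append(("What is a Moodle quiz?", "A quiz is an activity module."))
print(compose_query("How do I add one to my course?", history))
# -> "User: What is a Moodle quiz? Bot: A quiz is an activity module. How do I add one to my course?"

Note that the CAG branch still matches on the bare question, so a history-dependent follow-up will usually fall through to RAG rather than hit the cache.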
 
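On the CAG side, qa.json is loaded as a single flat object and split into qa_questions = list(qa_data.keys()) and qa_answers = list(qa_data.values()), so the file must map each canonical question string directly to its answer string. A hypothetical two-entry file, just to illustrate the expected shape (the entries are invented, not the project's data):

# Sketch: write a sample qa.json with the flat question -> answer layout the loader expects.
import json

sample_qa = {
    "How do I reset a student's password?":
        "Go to Site administration > Users > Browse list of users, then edit the account.",
    "How do I create a new course?":
        "Go to Site administration > Courses > Manage courses and categories > Create new course.",
}

with open("qa.json", "w", encoding="utf-8") as f:
    json.dump(sample_qa, f, ensure_ascii=False, indent=2)

A nested layout (e.g., a list of {"question": ..., "answer": ...} objects) would break the .keys()/.values() assumption above.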
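One caveat with this design: full_query grows with every recorded turn, and sentence-transformer encoders truncate inputs past their maximum sequence length. Because the current question is appended after the transcript, a long enough session can truncate the question itself out of the embedding. A hedged variant that embeds only the most recent turns; max_turns is an invented parameter, not something this commit defines:

# Sketch: cap how much history reaches the embedding model.
# max_turns is a hypothetical knob; the committed code embeds the full transcript.
def compose_query_windowed(user_query, chat_history, max_turns=3):
    recent = [(q, a) for q, a in chat_history if q][-max_turns:]  # drop the seed, keep the tail
    history_context = " ".join(f"User: {q} Bot: {a}" for q, a in recent)
    return (history_context + " " + user_query).strip()

Inside retrieve_from_rag(), full_query = compose_query_windowed(user_query, chat_history) would be a drop-in replacement.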