dylanglenister commited on
Commit
d85186d
·
1 Parent(s): 833527f

FEAT: Consulting knowledge base for RAG.

Browse files
Files changed (1) hide show
  1. src/core/memory_manager.py +74 -20
src/core/memory_manager.py CHANGED
@@ -2,13 +2,14 @@
2
 
3
  from src.data.connection import ActionFailed
4
  from src.data.repositories import account as account_repo
 
5
  from src.data.repositories import medical_memory as memory_repo
6
  from src.data.repositories import patient as patient_repo
7
  from src.data.repositories import session as session_repo
8
  from src.models.account import Account
9
  from src.models.patient import Patient
10
  from src.models.session import Message, Session
11
- from src.services import summariser
12
  from src.services.nvidia import nvidia_chat
13
  from src.utils.embeddings import EmbeddingClient
14
  from src.utils.logger import logger
@@ -245,23 +246,28 @@ class MemoryManager:
245
  logger().warning(f"Could not retrieve recent memories for enhanced context: {e}")
246
 
247
  # 2. Get semantically similar summaries (Long-Term Memory)
248
- if self.embedder:
249
  try:
250
  query_embedding = self.embedder.embed([question])[0]
251
- ltm_results = memory_repo.search_memories_semantic(
252
- patient_id=patient_id,
253
- query_embedding=query_embedding,
254
- limit=2
255
- )
256
- if ltm_results:
257
- ltm_summaries = [result.summary for result in ltm_results]
258
- context_parts.append("Semantically relevant medical history (LTM):\n" + "\n".join(ltm_summaries))
 
259
  except (ActionFailed, Exception) as e:
260
  logger().warning(f"Failed to perform LTM semantic search: {e}")
261
 
262
  # 3. Consult knowledge base
263
- info = self._consult_knowledge_base(question=question)
264
- context_parts.append(info)
 
 
 
 
265
 
266
  # 4. Get current conversation context
267
  try:
@@ -275,17 +281,65 @@ class MemoryManager:
275
  except ActionFailed as e:
276
  logger().warning(f"Could not retrieve current session context: {e}")
277
 
278
- return "\n\n".join(context_parts)
279
 
280
  # --- Private Helper Methods ---
281
 
282
- def _consult_knowledge_base(self, question: str) -> str:
283
- # 1. Embedding
284
- # Cannot use src/utils/embeddings.py because it uses sentence transformers while the knowledge base uses torch for embedding.
285
- # 2. Query
286
- # Use src/data/repositories/information.py to access the knowledge base stored on mongodb.
287
- # 3. Reponse
288
- # The result will need to be semanticly ranked. Suggested: https://build.nvidia.com/nvidia/rerank-qa-mistral-4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  return ""
290
 
291
  async def _update_session_title_if_first_message(
 
2
 
3
  from src.data.connection import ActionFailed
4
  from src.data.repositories import account as account_repo
5
+ from src.data.repositories import information as info_repo
6
  from src.data.repositories import medical_memory as memory_repo
7
  from src.data.repositories import patient as patient_repo
8
  from src.data.repositories import session as session_repo
9
  from src.models.account import Account
10
  from src.models.patient import Patient
11
  from src.models.session import Message, Session
12
+ from src.services import reranker, summariser
13
  from src.services.nvidia import nvidia_chat
14
  from src.utils.embeddings import EmbeddingClient
15
  from src.utils.logger import logger
 
246
  logger().warning(f"Could not retrieve recent memories for enhanced context: {e}")
247
 
248
  # 2. Get semantically similar summaries (Long-Term Memory)
249
+ if self.embedder and self.embedder.is_available():
250
  try:
251
  query_embedding = self.embedder.embed([question])[0]
252
+ if query_embedding:
253
+ ltm_results = memory_repo.search_memories_semantic(
254
+ patient_id=patient_id,
255
+ query_embedding=query_embedding,
256
+ limit=2
257
+ )
258
+ if ltm_results:
259
+ ltm_summaries = [result.summary for result in ltm_results]
260
+ context_parts.append("Semantically relevant medical history (LTM):\n" + "\n".join(ltm_summaries))
261
  except (ActionFailed, Exception) as e:
262
  logger().warning(f"Failed to perform LTM semantic search: {e}")
263
 
264
  # 3. Consult knowledge base
265
+ info = await self._consult_knowledge_base(
266
+ question=question,
267
+ nvidia_rotator=nvidia_rotator
268
+ )
269
+ if info:
270
+ context_parts.append(info)
271
 
272
  # 4. Get current conversation context
273
  try:
 
281
  except ActionFailed as e:
282
  logger().warning(f"Could not retrieve current session context: {e}")
283
 
284
+ return "\n\n".join(filter(None, context_parts))
285
 
286
  # --- Private Helper Methods ---
287
 
288
+ async def _consult_knowledge_base(
289
+ self,
290
+ question: str,
291
+ nvidia_rotator: APIKeyRotator
292
+ ) -> str:
293
+ """
294
+ Embeds a question, queries the knowledge base for relevant chunks,
295
+ reranks them, and formats them into a context string.
296
+ """
297
+ if not self.embedder or not self.embedder.is_available():
298
+ logger().warning("Embedder not available, skipping knowledge base consultation.")
299
+ return ""
300
+
301
+ try:
302
+ # 1. Embed the user's question
303
+ query_embedding = self.embedder.embed([question])[0]
304
+ if not query_embedding:
305
+ logger().warning("Failed to generate query embedding.")
306
+ return ""
307
+
308
+ # 2. Retrieve initial candidates from MongoDB
309
+ initial_chunks = info_repo.search_chunks_semantic(
310
+ query_embedding=query_embedding,
311
+ limit=10 # Retrieve more candidates for the reranker to process
312
+ )
313
+ if not initial_chunks:
314
+ logger().info("No relevant chunks found in the knowledge base.")
315
+ return ""
316
+
317
+ # 3. Rerank the results for semantic relevance
318
+ reranked_chunks = await reranker.rerank_documents(
319
+ query=question,
320
+ documents=initial_chunks,
321
+ rotator=nvidia_rotator,
322
+ top_k=3 # Keep the top 3 most relevant results
323
+ )
324
+ if not reranked_chunks:
325
+ logger().warning("Reranking failed to return any chunks.")
326
+ return ""
327
+
328
+ # 4. Format the final response
329
+ context_header = "Consulted Knowledge Base for context:"
330
+ formatted_chunks = []
331
+ for chunk in reranked_chunks:
332
+ source = chunk.metadata.source
333
+ content = chunk.content.strip()
334
+ formatted_chunks.append(f"[Source: {source}]\n{content}")
335
+
336
+ return f"{context_header}\n\n" + "\n\n".join(formatted_chunks)
337
+
338
+ except ActionFailed as e:
339
+ logger().error(f"A database error occurred while consulting the knowledge base: {e}")
340
+ except Exception as e:
341
+ logger().error(f"An unexpected error occurred during knowledge base consultation: {e}")
342
+
343
  return ""
344
 
345
  async def _update_session_title_if_first_message(