dylanglenister committed
Commit 24ed5c4 · 1 Parent(s): 84d39f9

CHORE: Memory and summariser housekeeping

src/core/memory_manager.py CHANGED
@@ -177,9 +177,12 @@ class MemoryManager:
             session_repo.add_message(session_id, answer, sent_by_user=False)
 
             # 2. Generate a concise summary of the exchange
-            summary = await self._generate_summary(question, answer, gemini_rotator, nvidia_rotator)
-            if not summary:
-                return None # Could not generate a summary
+            summary = await self._generate_summary(
+                question=question,
+                answer=answer,
+                gemini_rotator=gemini_rotator,
+                nvidia_rotator=nvidia_rotator
+            )
 
             # 3. Generate an embedding for the summary for semantic search
             embedding = None
@@ -199,7 +202,11 @@
             )
 
             # 5. Update the session title if this was the first exchange
-            await self._update_session_title_if_first_message(session_id, question, nvidia_rotator)
+            await self._update_session_title_if_first_message(
+                session_id=session_id,
+                question=question,
+                nvidia_rotator=nvidia_rotator
+            )
 
             return summary
         except ActionFailed as e:
@@ -228,7 +235,9 @@
         if recent_memories:
             # Use NVIDIA to reason about relevance
             relevant_stm = await self._filter_summaries_for_relevance(
-                question, [mem.summary for mem in recent_memories], nvidia_rotator
+                question=question,
+                summaries=[mem.summary for mem in recent_memories],
+                nvidia_rotator=nvidia_rotator
             )
             if relevant_stm:
                 context_parts.append("Recent relevant medical context (STM):\n" + "\n".join(relevant_stm))
@@ -239,14 +248,21 @@
         if self.embedder:
             try:
                 query_embedding = self.embedder.embed([question])[0]
-                ltm_results = memory_repo.search_memories_semantic(patient_id, query_embedding, limit=2)
+                ltm_results = memory_repo.search_memories_semantic(
+                    patient_id=patient_id,
+                    query_embedding=query_embedding,
+                    limit=2
+                )
                 if ltm_results:
                     ltm_summaries = [result.summary for result in ltm_results]
                     context_parts.append("Semantically relevant medical history (LTM):\n" + "\n".join(ltm_summaries))
             except (ActionFailed, Exception) as e:
                 logger().warning(f"Failed to perform LTM semantic search: {e}")
 
-        # 3. Get current conversation context
+        # 3. Consult knowledge base
+        # TODO
+
+        # 4. Get current conversation context
         try:
             session = session_repo.get_session(session_id)
             if session and session.messages:
@@ -273,10 +289,10 @@
             session = self.get_session(session_id)
             # Check if it's the first user message and first assistant response
             if session and len(session.messages) == 2:
-                title = await summariser.summarise_title_with_nvidia(question, nvidia_rotator, max_words=5)
+                title = await summariser.summarise_title_with_nvidia(text=question, rotator=nvidia_rotator, max_words=5)
                 if not title:
                     title = question[:80] # Fallback to first 80 chars
-                self.update_session_title(session_id, title)
+                self.update_session_title(session_id=session_id, title=title)
         except Exception as e:
             logger().warning(f"Failed to auto-update session title for session '{session_id}': {e}")
 
@@ -289,18 +305,25 @@
     ) -> str:
         """Generates a summary of a Q&A exchange, falling back to a basic format if AI fails."""
         try:
-            summary = await summariser.summarise_qa_with_gemini(question, answer, gemini_rotator)
-            if summary:
-                return summary
+            summary = await summariser.summarise_qa_with_gemini(
+                question=question,
+                answer=answer,
+                rotator=gemini_rotator
+            )
+            if summary: return summary
+
             # Fallback to NVIDIA if Gemini fails
-            summary = await summariser.summarise_qa_with_nvidia(question, answer, nvidia_rotator)
-            if summary:
-                return summary
+            summary = await summariser.summarise_qa_with_nvidia(
+                question=question,
+                answer=answer,
+                rotator=nvidia_rotator
+            )
+            if summary: return summary
        except Exception as e:
            logger().warning(f"Failed to generate AI summary: {e}")
 
        # Fallback for both exceptions and cases where services return None
-       return f"Question: {question}\nAnswer: {answer}"
+       return summariser.summarise_fallback(question=question, answer=answer)
 
    async def _filter_summaries_for_relevance(
        self,
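
Note: with the early return None removed in the first hunk, _generate_summary now always hands a string back to its caller, with summariser.summarise_fallback as the worst case, so the embedding and storage steps that follow always run. Below is a minimal, self-contained sketch of that Gemini-then-NVIDIA-then-plain-text fallback chain; the stub coroutines only stand in for the real summariser calls and key rotators, which are not shown here.

import asyncio

async def summarise_with_gemini(question: str, answer: str) -> str | None:
    # Stand-in for summariser.summarise_qa_with_gemini: None signals failure.
    return None  # pretend the Gemini call did not produce a usable summary

async def summarise_with_nvidia(question: str, answer: str) -> str | None:
    # Stand-in for summariser.summarise_qa_with_nvidia.
    return f"q: {question.strip()[:160]}\na: {answer.strip()[:220]}"

def summarise_fallback(question: str, answer: str) -> str:
    # Mirrors the new summariser.summarise_fallback.
    return f"q: {question.strip()[:160]}\na: {answer.strip()[:220]}"

async def generate_summary(question: str, answer: str) -> str:
    # Same chain as MemoryManager._generate_summary after this commit:
    # try Gemini, then NVIDIA, then the deterministic fallback.
    try:
        summary = await summarise_with_gemini(question, answer)
        if summary:
            return summary
        summary = await summarise_with_nvidia(question, answer)
        if summary:
            return summary
    except Exception:
        pass  # fall through to the deterministic fallback
    return summarise_fallback(question, answer)

if __name__ == "__main__":
    print(asyncio.run(generate_summary("What causes migraines?", "Common triggers include stress and poor sleep.")))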
src/services/summariser.py CHANGED
@@ -32,7 +32,7 @@ async def summarise_qa_with_gemini(
     question: str,
     answer: str,
     rotator: APIKeyRotator
-) -> str:
+) -> str | None:
     """Summarizes a Q&A pair into a 'q: ... a: ...' format using the Gemini API."""
     prompt = prompt_builder.qa_summary_gemini_prompt(question, answer)
     response = await gemini_chat(prompt, rotator)
@@ -45,14 +45,15 @@ async def summarise_qa_with_gemini(
     if q_line and a_line:
         return f"{q_line}\n{a_line}"
 
-    logger().warning("Gemini summarization failed, using fallback.")
-    return f"q: {question.strip()[:160]}\na: {answer.strip()[:220]}"
+    #logger().warning("Gemini summarization failed, using fallback.")
+    #return f"q: {question.strip()[:160]}\na: {answer.strip()[:220]}"
+    return None
 
 async def summarise_qa_with_nvidia(
     question: str,
     answer: str,
     rotator: APIKeyRotator
-) -> str:
+) -> str | None:
     """Summarizes a Q&A pair into a 'q: ... a: ...' format using the NVIDIA API."""
     sys_prompt = "You are a terse summariser. Output exactly two lines:\nq: <short question summary>\na: <short answer summary>\nNo extra text."
     user_prompt = f"Question:\n{question}\n\nAnswer:\n{answer}"
@@ -65,6 +66,11 @@ async def summarise_qa_with_nvidia(
     if q_line and a_line:
         return f"{q_line}\n{a_line}"
 
-    q_fallback = "q: " + (question.strip()[:160] + "…")
-    a_fallback = "a: " + (answer.strip()[:220] + "…")
-    return f"{q_fallback}\n{a_fallback}"
+    #q_fallback = "q: " + (question.strip()[:160] + "…")
+    #a_fallback = "a: " + (answer.strip()[:220] + "…")
+    #return f"{q_fallback}\n{a_fallback}"
+
+    return None
+
+def summarise_fallback(question: str, answer: str):
+    return f"q: {question.strip()[:160]}\na: {answer.strip()[:220]}"
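
Note: both summarise_qa_with_* functions now return None on failure instead of silently substituting truncated text, and the truncation logic lives only in summarise_fallback. The str | None annotation is evaluated at import time, so it requires Python 3.10+ unless from __future__ import annotations is used. Below is a small sketch of how a caller would honour the new contract; pick_summary is a hypothetical helper used only for illustration, not part of the repository.

from typing import Optional

def summarise_fallback(question: str, answer: str) -> str:
    # Same shape as the new summariser.summarise_fallback:
    # question capped at 160 characters, answer at 220, on two q:/a: lines.
    return f"q: {question.strip()[:160]}\na: {answer.strip()[:220]}"

def pick_summary(ai_summary: Optional[str], question: str, answer: str) -> str:
    # Callers now receive None when the model calls fail and choose the
    # fallback themselves, as MemoryManager._generate_summary does.
    return ai_summary if ai_summary else summarise_fallback(question, answer)

print(pick_summary(None, "  What is hypertension?  ", "Persistently raised blood pressure. " * 10))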