dylanglenister committed on
Commit
47e3582
·
1 Parent(s): 5a8e374

REFACTOR: Improve chat pipeline.

Browse files

Refactoring the code pipeline for sending a message to prepare for implementing RAG.

src/api/routes/session.py CHANGED
@@ -1,14 +1,11 @@
1
  # src/api/routes/chat.py
2
 
3
- from datetime import datetime, timezone
4
-
5
  from fastapi import APIRouter, Depends, HTTPException, status
6
 
 
7
  from src.core.state import AppState, get_state
8
  from src.models.session import (ChatRequest, ChatResponse, Message, Session,
9
  SessionCreateRequest)
10
- from src.services.medical_response import generate_medical_response
11
- from src.services.guard import SafetyGuard
12
  from src.utils.logger import logger
13
 
14
  router = APIRouter(prefix="/session", tags=["Session & Chat"])
@@ -85,80 +82,21 @@ async def post_chat_message(
85
  and persists the full exchange to long-term memory.
86
  """
87
  logger().info(f"POST /session/{session_id}/messages")
88
-
89
- # 0. Safety Guard: Validate user query
90
- try:
91
- safety_guard = SafetyGuard(state.nvidia_rotator)
92
- is_safe, safety_reason = safety_guard.check_user_query(req.message)
93
- if not is_safe:
94
- logger().warning(f"Safety guard blocked user query: {safety_reason}")
95
- raise HTTPException(
96
- status_code=status.HTTP_400_BAD_REQUEST,
97
- detail=f"Query blocked for safety reasons: {safety_reason}"
98
- )
99
- logger().info(f"User query passed safety validation: {safety_reason}")
100
- except Exception as e:
101
- logger().error(f"Safety guard error: {e}")
102
- # Fail open for now - allow query through if guard fails
103
- logger().warning("Safety guard failed, allowing query through")
104
-
105
- # 1. Get Enhanced Context
106
  try:
107
- medical_context = await state.memory_manager.get_enhanced_context(
 
 
108
  session_id=session_id,
109
  patient_id=req.patient_id,
110
- question=req.message,
111
- nvidia_rotator=state.nvidia_rotator
112
  )
 
 
 
 
113
  except Exception as e:
114
- logger().error(f"Error getting medical context: {e}")
115
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to build medical context.")
116
-
117
- # 2. Generate AI Response
118
- try:
119
- # In a real app, user role/specialty would come from the authenticated user
120
- response_text = await generate_medical_response(
121
- user_message=req.message,
122
- user_role="Medical Professional",
123
- user_specialty="",
124
- rotator=state.gemini_rotator,
125
- medical_context=medical_context,
126
- nvidia_rotator=state.nvidia_rotator
127
  )
128
- except Exception as e:
129
- logger().error(f"Error generating medical response: {e}")
130
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to generate AI response.")
131
-
132
- # 2.5. Safety Guard: Validate AI response
133
- try:
134
- is_safe, safety_reason = safety_guard.check_model_answer(req.message, response_text)
135
- if not is_safe:
136
- logger().warning(f"Safety guard blocked AI response: {safety_reason}")
137
- # Replace with safe fallback response
138
- response_text = "I apologize, but I cannot provide a response to that query as it may contain unsafe content. Please consult with a qualified healthcare professional for medical advice."
139
- else:
140
- logger().info(f"AI response passed safety validation: {safety_reason}")
141
- except Exception as e:
142
- logger().error(f"Safety guard error for response: {e}")
143
- # Fail open for now - allow response through if guard fails
144
- logger().warning("Safety guard failed for response, allowing through")
145
-
146
- # 3. Process and Store the Exchange
147
- summary = await state.memory_manager.process_medical_exchange(
148
- session_id=session_id,
149
- patient_id=req.patient_id,
150
- doctor_id=req.account_id,
151
- question=req.message,
152
- answer=response_text,
153
- gemini_rotator=state.gemini_rotator,
154
- nvidia_rotator=state.nvidia_rotator
155
- )
156
- if not summary:
157
- logger().warning(f"Failed to process and store medical exchange for session {session_id}")
158
-
159
- return ChatResponse(
160
- response=response_text,
161
- session_id=session_id,
162
- timestamp=datetime.now(timezone.utc),
163
- medical_context=medical_context
164
- )
 
1
  # src/api/routes/chat.py
2
 
 
 
3
  from fastapi import APIRouter, Depends, HTTPException, status
4
 
5
+ from src.core.response_pipeline import generate_chat_response
6
  from src.core.state import AppState, get_state
7
  from src.models.session import (ChatRequest, ChatResponse, Message, Session,
8
  SessionCreateRequest)
 
 
9
  from src.utils.logger import logger
10
 
11
  router = APIRouter(prefix="/session", tags=["Session & Chat"])
 
82
  and persists the full exchange to long-term memory.
83
  """
84
  logger().info(f"POST /session/{session_id}/messages")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  try:
86
+ response = await generate_chat_response(
87
+ state=state,
88
+ message=req.message,
89
  session_id=session_id,
90
  patient_id=req.patient_id,
91
+ account_id=req.account_id
 
92
  )
93
+ return ChatResponse(response=response)
94
+ except HTTPException as e:
95
+ # Re-raise HTTPException to let FastAPI handle it
96
+ raise e
97
  except Exception as e:
98
+ logger().error(f"Unhandled error in chat pipeline: {e}")
99
+ raise HTTPException(
100
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
101
+ detail="An unexpected error occurred."
 
 
 
 
 
 
 
 
 
102
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/core/prompt_builder.py CHANGED
@@ -2,15 +2,20 @@
2
 
3
  import json
4
 
 
5
  from src.models.medical import MedicalMemory
6
 
7
 
8
- def medical_response_prompt(user_role: str, user_specialty: str, medical_context: str, user_message: str) -> str:
 
 
 
 
9
  """Generates the prompt for creating a medical response."""
10
  return f"""You are a knowledgeable medical AI assistant. Provide a comprehensive, accurate, and helpful response to this medical question.
11
- **User Role:** {user_role}
12
- **User Specialty:** {user_specialty if user_specialty else 'General'}
13
- **Medical Context:** {medical_context if medical_context else 'No previous context'}
14
  **Question:** {user_message}
15
  **Instructions:**
16
  1. Provide a detailed, medically accurate response.
 
2
 
3
  import json
4
 
5
+ from src.models.account import Account
6
  from src.models.medical import MedicalMemory
7
 
8
 
9
+ def medical_response_prompt(
10
+ account: Account,
11
+ user_message: str,
12
+ medical_context: str | None = None
13
+ ) -> str:
14
  """Generates the prompt for creating a medical response."""
15
  return f"""You are a knowledgeable medical AI assistant. Provide a comprehensive, accurate, and helpful response to this medical question.
16
+ **User Role:** {account.role}
17
+ **User Specialty:** {account.specialty or "No specialty"}
18
+ **Medical Context:** {medical_context or 'No previous context'}
19
  **Question:** {user_message}
20
  **Instructions:**
21
  1. Provide a detailed, medically accurate response.
src/{services/medical_response.py → core/response_pipeline.py} RENAMED
@@ -1,23 +1,116 @@
1
- # src/services/medical_response.py
 
 
2
 
3
  from src.core import prompt_builder
 
4
  from src.data.medical_kb import search_medical_kb
 
5
  from src.services.gemini import gemini_chat
6
  from src.services.guard import SafetyGuard
7
  from src.utils.logger import logger
8
  from src.utils.rotator import APIKeyRotator
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  async def generate_medical_response(
12
- user_message: str,
13
- user_role: str,
14
- user_specialty: str,
15
  rotator: APIKeyRotator,
16
  medical_context: str = "",
17
- nvidia_rotator: APIKeyRotator = None
18
  ) -> str:
19
  """Generates an intelligent, contextual medical response using Gemini AI."""
20
- prompt = prompt_builder.medical_response_prompt(user_role, user_specialty, medical_context, user_message)
 
 
 
 
21
 
22
  # Generate response using Gemini
23
  response_text = await gemini_chat(prompt, rotator)
@@ -27,11 +120,12 @@ async def generate_medical_response(
27
  if "disclaimer" not in response_text.lower() and "consult" not in response_text.lower():
28
  response_text += "\n\n⚠️ **Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice, diagnosis, or treatment. Always consult with qualified healthcare professionals."
29
 
 
30
  # Safety Guard: Validate the generated response
31
  if nvidia_rotator:
32
  try:
33
  safety_guard = SafetyGuard(nvidia_rotator)
34
- is_safe, safety_reason = safety_guard.check_model_answer(user_message, response_text)
35
  if not is_safe:
36
  logger().warning(f"Safety guard blocked generated response: {safety_reason}")
37
  # Return safe fallback response
@@ -47,16 +141,14 @@ async def generate_medical_response(
47
  return response_text
48
 
49
  logger().warning("Gemini response failed, using fallback.")
50
- return _generate_fallback_response(user_message, user_role, user_specialty)
51
 
52
  def _generate_fallback_response(
53
- user_message: str,
54
- user_role: str,
55
- user_specialty: str,
56
- medical_context: str = ""
57
  ) -> str:
58
  """Generates a fallback response using a local knowledge base."""
59
- kb_info = search_medical_kb(user_message)
60
 
61
  logger().info("Generating backup response")
62
 
@@ -64,7 +156,7 @@ def _generate_fallback_response(
64
  response_parts = []
65
 
66
  # Analyze the question to provide more specific responses
67
- question_lower = user_message.lower()
68
 
69
  if kb_info:
70
  response_parts.append(f"Based on your question about medical topics, here's what I found:\n\n{kb_info}")
@@ -117,21 +209,21 @@ def _generate_fallback_response(
117
  response_parts.append("Thank you for your medical question. While I can provide general information, it's important to consult with healthcare professionals for personalized medical advice.")
118
 
119
  # Add role-specific guidance
120
- if user_role.lower() in ["physician", "doctor", "nurse"]:
121
  response_parts.append("\n\n**Professional Context:** As a healthcare professional, you're likely familiar with these concepts. Remember to always follow your institution's protocols and guidelines, and consider the latest clinical evidence in your practice.")
122
- elif user_role.lower() in ["medical student", "student"]:
123
  response_parts.append("\n\n**Educational Context:** As a medical student, this information can help with your studies. Always verify information with your professors and clinical supervisors, and use this as a starting point for further research.")
124
- elif user_role.lower() in ["patient"]:
125
  response_parts.append("\n\n**Patient Context:** As a patient, this information is for educational purposes only. Please discuss any concerns with your healthcare provider, and don't make treatment decisions based solely on this information.")
126
  else:
127
  response_parts.append("\n\n**General Context:** This information is provided for educational purposes. Always consult with qualified healthcare professionals for medical advice.")
128
 
129
  # Add specialty-specific information if available
130
- if user_specialty and user_specialty.lower() in ["cardiology", "cardiac"]:
131
  response_parts.append("\n\n**Cardiology Perspective:** Given your interest in cardiology, consider how this information relates to cardiovascular health and patient care. Many conditions can have cardiac implications.")
132
- elif user_specialty and user_specialty.lower() in ["pediatrics", "pediatric"]:
133
  response_parts.append("\n\n**Pediatric Perspective:** In pediatric care, remember that children may present differently than adults and may require specialized approaches. Consider age-appropriate considerations.")
134
- elif user_specialty and user_specialty.lower() in ["emergency", "er"]:
135
  response_parts.append("\n\n**Emergency Medicine Perspective:** In emergency settings, rapid assessment and intervention are crucial. Consider the urgency and severity of presenting symptoms.")
136
 
137
  # Add medical disclaimer
 
1
+ # src/core/response_pipeline.py
2
+
3
+ from fastapi import HTTPException, status
4
 
5
  from src.core import prompt_builder
6
+ from src.core.state import AppState
7
  from src.data.medical_kb import search_medical_kb
8
+ from src.models.account import Account
9
  from src.services.gemini import gemini_chat
10
  from src.services.guard import SafetyGuard
11
  from src.utils.logger import logger
12
  from src.utils.rotator import APIKeyRotator
13
 
14
 
15
+ async def generate_chat_response(
16
+ state: AppState,
17
+ message: str,
18
+ session_id: str,
19
+ patient_id: str,
20
+ account_id: str
21
+ ) -> str:
22
+ """
23
+ Handles the pipeline for generating a chat response, including safety checks,
24
+ context retrieval, response generation, and memory persistence.
25
+ """
26
+ logger().info(f"Starting response pipeline for session {session_id}")
27
+
28
+ # 0. Safety Guard: Validate user query
29
+ try:
30
+ safety_guard = SafetyGuard(state.nvidia_rotator)
31
+ is_safe, safety_reason = safety_guard.check_user_query(message)
32
+ if not is_safe:
33
+ logger().warning(f"Safety guard blocked user query: {safety_reason}")
34
+ raise HTTPException(
35
+ status_code=status.HTTP_400_BAD_REQUEST,
36
+ detail=f"Query blocked for safety reasons: {safety_reason}"
37
+ )
38
+ logger().info(f"User query passed safety validation: {safety_reason}")
39
+ except Exception as e:
40
+ logger().error(f"Safety guard error: {e}")
41
+ raise e
42
+
43
+ # 1. Get Enhanced Context
44
+ # TODO Implement RAG
45
+ try:
46
+ medical_context = await state.memory_manager.get_enhanced_context(
47
+ session_id=session_id,
48
+ patient_id=patient_id,
49
+ question=message,
50
+ nvidia_rotator=state.nvidia_rotator
51
+ )
52
+ except Exception as e:
53
+ logger().error(f"Error getting medical context: {e}")
54
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to build medical context.")
55
+
56
+ account = state.memory_manager.get_account(account_id)
57
+ if not account:
58
+ raise Exception("Account not found")
59
+
60
+ # 2. Generate AI Response
61
+ try:
62
+ response_text = await generate_medical_response(
63
+ message=message,
64
+ account=account,
65
+ rotator=state.gemini_rotator,
66
+ medical_context=medical_context,
67
+ nvidia_rotator=state.nvidia_rotator
68
+ )
69
+ except Exception as e:
70
+ logger().error(f"Error generating medical response: {e}")
71
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to generate AI response.")
72
+
73
+ # TODO Safety guard is applied twice
74
+ # 2.5. Safety Guard: Validate AI response
75
+ try:
76
+ is_safe, safety_reason = safety_guard.check_model_answer(message, response_text) # type: ignore
77
+ if not is_safe:
78
+ logger().warning(f"Safety guard blocked AI response: {safety_reason}")
79
+ response_text = "I apologize, but I cannot provide a response to that query as it may contain unsafe content. Please consult with a qualified healthcare professional for medical advice."
80
+ else:
81
+ logger().info(f"AI response passed safety validation: {safety_reason}")
82
+ except Exception as e:
83
+ logger().error(f"Safety guard error for response: {e}")
84
+ logger().warning("Safety guard failed for response, allowing through")
85
+
86
+ # 3. Process and Store the Exchange
87
+ summary = await state.memory_manager.process_medical_exchange(
88
+ session_id=session_id,
89
+ patient_id=patient_id,
90
+ doctor_id=account_id,
91
+ question=message,
92
+ answer=response_text,
93
+ gemini_rotator=state.gemini_rotator,
94
+ nvidia_rotator=state.nvidia_rotator
95
+ )
96
+ if not summary:
97
+ logger().warning(f"Failed to process and store medical exchange for session {session_id}")
98
+
99
+ return response_text
100
+
101
  async def generate_medical_response(
102
+ account: Account,
103
+ message: str,
 
104
  rotator: APIKeyRotator,
105
  medical_context: str = "",
106
+ nvidia_rotator: APIKeyRotator | None = None
107
  ) -> str:
108
  """Generates an intelligent, contextual medical response using Gemini AI."""
109
+ prompt = prompt_builder.medical_response_prompt(
110
+ account=account,
111
+ user_message=message,
112
+ medical_context=medical_context
113
+ )
114
 
115
  # Generate response using Gemini
116
  response_text = await gemini_chat(prompt, rotator)
 
120
  if "disclaimer" not in response_text.lower() and "consult" not in response_text.lower():
121
  response_text += "\n\n⚠️ **Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice, diagnosis, or treatment. Always consult with qualified healthcare professionals."
122
 
123
+ # TODO Safety guard is applied to the response twice
124
  # Safety Guard: Validate the generated response
125
  if nvidia_rotator:
126
  try:
127
  safety_guard = SafetyGuard(nvidia_rotator)
128
+ is_safe, safety_reason = safety_guard.check_model_answer(message, response_text)
129
  if not is_safe:
130
  logger().warning(f"Safety guard blocked generated response: {safety_reason}")
131
  # Return safe fallback response
 
141
  return response_text
142
 
143
  logger().warning("Gemini response failed, using fallback.")
144
+ return _generate_fallback_response(message=message, account=account)
145
 
146
  def _generate_fallback_response(
147
+ message: str,
148
+ account: Account
 
 
149
  ) -> str:
150
  """Generates a fallback response using a local knowledge base."""
151
+ kb_info = search_medical_kb(message)
152
 
153
  logger().info("Generating backup response")
154
 
 
156
  response_parts = []
157
 
158
  # Analyze the question to provide more specific responses
159
+ question_lower = message.lower()
160
 
161
  if kb_info:
162
  response_parts.append(f"Based on your question about medical topics, here's what I found:\n\n{kb_info}")
 
209
  response_parts.append("Thank you for your medical question. While I can provide general information, it's important to consult with healthcare professionals for personalized medical advice.")
210
 
211
  # Add role-specific guidance
212
+ if account.role.lower() in ["physician", "doctor", "nurse"]:
213
  response_parts.append("\n\n**Professional Context:** As a healthcare professional, you're likely familiar with these concepts. Remember to always follow your institution's protocols and guidelines, and consider the latest clinical evidence in your practice.")
214
+ elif account.role.lower() in ["medical student", "student"]:
215
  response_parts.append("\n\n**Educational Context:** As a medical student, this information can help with your studies. Always verify information with your professors and clinical supervisors, and use this as a starting point for further research.")
216
+ elif account.role.lower() in ["patient"]:
217
  response_parts.append("\n\n**Patient Context:** As a patient, this information is for educational purposes only. Please discuss any concerns with your healthcare provider, and don't make treatment decisions based solely on this information.")
218
  else:
219
  response_parts.append("\n\n**General Context:** This information is provided for educational purposes. Always consult with qualified healthcare professionals for medical advice.")
220
 
221
  # Add specialty-specific information if available
222
+ if account.specialty and account.specialty.lower() in ["cardiology", "cardiac"]:
223
  response_parts.append("\n\n**Cardiology Perspective:** Given your interest in cardiology, consider how this information relates to cardiovascular health and patient care. Many conditions can have cardiac implications.")
224
+ elif account.specialty and account.specialty.lower() in ["pediatrics", "pediatric"]:
225
  response_parts.append("\n\n**Pediatric Perspective:** In pediatric care, remember that children may present differently than adults and may require specialized approaches. Consider age-appropriate considerations.")
226
+ elif account.specialty and account.specialty.lower() in ["emergency", "er"]:
227
  response_parts.append("\n\n**Emergency Medicine Perspective:** In emergency settings, rapid assessment and intervention are crucial. Consider the urgency and severity of presenting symptoms.")
228
 
229
  # Add medical disclaimer
src/models/session.py CHANGED
@@ -44,13 +44,9 @@ class ChatRequest(BaseModel):
44
  account_id: str # For context, though session_id implies this
45
  patient_id: str # For context, though session_id implies this
46
  message: str
47
- session_id: str | None = None # Optional session ID for continuing existing sessions
48
 
49
  # --- API Response Models ---
50
 
51
  class ChatResponse(BaseModel):
52
  """Response model for a chat interaction."""
53
  response: str
54
- session_id: str
55
- timestamp: datetime
56
- medical_context: str | None = None
 
44
  account_id: str # For context, though session_id implies this
45
  patient_id: str # For context, though session_id implies this
46
  message: str
 
47
 
48
  # --- API Response Models ---
49
 
50
  class ChatResponse(BaseModel):
51
  """Response model for a chat interaction."""
52
  response: str