dylanglenister committed on
Commit
84d39f9
·
1 Parent(s): 6d1027d

REFACTOR: Refactoring chat pipeline.

Browse files

Using functional decomposition to make working on this file easier in the future.

Files changed (1) hide show
  1. src/core/response_pipeline.py +153 -90
src/core/response_pipeline.py CHANGED
@@ -11,39 +11,64 @@ from src.services.guard import SafetyGuard
11
  from src.utils.logger import logger
12
  from src.utils.rotator import APIKeyRotator
13
 
 
14
 
15
- async def generate_chat_response(
16
- state: AppState,
17
- message: str,
18
- session_id: str,
19
- patient_id: str,
20
- account_id: str
21
- ) -> str:
22
  """
23
- Handles the pipeline for generating a chat response, including safety checks,
24
- context retrieval, response generation, and memory persistence.
25
  """
26
- logger().info(f"Starting response pipeline for session {session_id}")
27
-
28
- # 0. Safety Guard: Validate user query
29
  try:
30
- safety_guard = SafetyGuard(state.nvidia_rotator)
31
- is_safe, safety_reason = safety_guard.check_user_query(message)
32
  if not is_safe:
33
- logger().warning(f"Safety guard blocked user query: {safety_reason}")
34
  raise HTTPException(
35
  status_code=status.HTTP_400_BAD_REQUEST,
36
- detail=f"Query blocked for safety reasons: {safety_reason}"
37
  )
38
- logger().info(f"User query passed safety validation: {safety_reason}")
39
  except Exception as e:
40
- logger().error(f"Safety guard error: {e}")
41
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # 1. Get Enhanced Context
44
- # TODO Implement RAG
 
 
 
45
  try:
46
- medical_context = await state.memory_manager.get_enhanced_context(
47
  session_id=session_id,
48
  patient_id=patient_id,
49
  question=message,
@@ -51,97 +76,135 @@ async def generate_chat_response(
51
  )
52
  except Exception as e:
53
  logger().error(f"Error getting medical context: {e}")
54
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to build medical context.")
55
-
56
- account = state.memory_manager.get_account(account_id)
57
- if not account:
58
- raise Exception("Account not found")
59
-
60
- # 2. Generate AI Response
61
- try:
62
- response_text = await generate_medical_response(
63
- message=message,
64
- account=account,
65
- rotator=state.gemini_rotator,
66
- medical_context=medical_context,
67
- nvidia_rotator=state.nvidia_rotator
68
- )
69
- except Exception as e:
70
- logger().error(f"Error generating medical response: {e}")
71
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to generate AI response.")
72
-
73
- # TODO Safety guard is applied twice
74
- # 2.5. Safety Guard: Validate AI response
75
- try:
76
- is_safe, safety_reason = safety_guard.check_model_answer(message, response_text) # type: ignore
77
- if not is_safe:
78
- logger().warning(f"Safety guard blocked AI response: {safety_reason}")
79
- response_text = "I apologize, but I cannot provide a response to that query as it may contain unsafe content. Please consult with a qualified healthcare professional for medical advice."
80
- else:
81
- logger().info(f"AI response passed safety validation: {safety_reason}")
82
- except Exception as e:
83
- logger().error(f"Safety guard error for response: {e}")
84
- logger().warning("Safety guard failed for response, allowing through")
85
 
86
- # 3. Process and Store the Exchange
 
 
 
 
 
 
 
 
87
  summary = await state.memory_manager.process_medical_exchange(
88
  session_id=session_id,
89
  patient_id=patient_id,
90
  doctor_id=account_id,
91
- question=message,
92
- answer=response_text,
93
  gemini_rotator=state.gemini_rotator,
94
  nvidia_rotator=state.nvidia_rotator
95
  )
96
  if not summary:
97
  logger().warning(f"Failed to process and store medical exchange for session {session_id}")
98
 
99
- return response_text
100
 
101
- async def generate_medical_response(
 
 
102
  account: Account,
103
  message: str,
104
  rotator: APIKeyRotator,
105
- medical_context: str = "",
106
- nvidia_rotator: APIKeyRotator | None = None
107
- ) -> str:
108
- """Generates an intelligent, contextual medical response using Gemini AI."""
 
 
109
  prompt = prompt_builder.medical_response_prompt(
110
  account=account,
111
  user_message=message,
112
  medical_context=medical_context
113
  )
114
 
115
- # Generate response using Gemini
116
  response_text = await gemini_chat(prompt, rotator)
117
 
118
- if response_text:
119
- # Add medical disclaimer if not already present
120
- if "disclaimer" not in response_text.lower() and "consult" not in response_text.lower():
121
- response_text += "\n\n⚠️ **Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice, diagnosis, or treatment. Always consult with qualified healthcare professionals."
122
-
123
- # TODO Safety guard is applied to the response twice
124
- # Safety Guard: Validate the generated response
125
- if nvidia_rotator:
126
- try:
127
- safety_guard = SafetyGuard(nvidia_rotator)
128
- is_safe, safety_reason = safety_guard.check_model_answer(message, response_text)
129
- if not is_safe:
130
- logger().warning(f"Safety guard blocked generated response: {safety_reason}")
131
- # Return safe fallback response
132
- return "I apologize, but I cannot provide a response to that query as it may contain unsafe content. Please consult with a qualified healthcare professional for medical advice."
133
- else:
134
- logger().info(f"Generated response passed safety validation: {safety_reason}")
135
- except Exception as e:
136
- logger().error(f"Safety guard error in medical response: {e}")
137
- # Fail open for now - allow response through if guard fails
138
- logger().warning("Safety guard failed, allowing generated response through")
139
-
140
- logger().info(f"Gemini response generated, length: {len(response_text)} chars")
141
- return response_text
142
-
143
- logger().warning("Gemini response failed, using fallback.")
144
- return _generate_fallback_response(message=message, account=account)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  def _generate_fallback_response(
147
  message: str,
 
11
  from src.utils.logger import logger
12
  from src.utils.rotator import APIKeyRotator
13
 
14
+ # --- Private Helper Functions ---
15
 
16
+ def _validate_user_query(message: str, safety_guard: SafetyGuard | None):
 
 
 
 
 
 
17
  """
18
+ Checks the user's query against the safety guard.
19
+ Raises an HTTPException if the query is unsafe.
20
  """
21
+ if not safety_guard: return
 
 
22
  try:
23
+ is_safe, reason = safety_guard.check_user_query(message)
 
24
  if not is_safe:
25
+ logger().warning(f"Safety guard blocked user query: {reason}")
26
  raise HTTPException(
27
  status_code=status.HTTP_400_BAD_REQUEST,
28
+ detail=f"Query blocked for safety reasons: {reason}"
29
  )
30
+ logger().info(f"User query passed safety validation: {reason}")
31
  except Exception as e:
32
+ logger().error(f"Safety guard failed on user query: {e}")
33
+ # Re-raise to be caught by the main orchestrator
34
+ raise HTTPException(
35
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
36
+ detail="Failed to validate user query safety."
37
+ ) from e
38
+
39
+ def _validate_model_response(query: str, response: str, safety_guard: SafetyGuard | None) -> str:
40
+ """
41
+ Checks the generated model response against the safety guard.
42
+ Returns a safe fallback message if the response is deemed unsafe.
43
+ """
44
+ if not safety_guard: return response
45
+ safe_fallback = "I apologize, but I cannot provide a response to that query as it may contain unsafe content. Please consult with a qualified healthcare professional for medical advice."
46
+ try:
47
+ is_safe, reason = safety_guard.check_model_answer(query, response)
48
+ if not is_safe:
49
+ logger().warning(f"Safety guard blocked AI response: {reason}")
50
+ return safe_fallback
51
+ logger().info(f"AI response passed safety validation: {reason}")
52
+ return response
53
+ except Exception as e:
54
+ logger().error(f"Safety guard failed on model response: {e}")
55
+ logger().warning("Safety guard failed, allowing response through (fail-open)")
56
+ # Fail open: return the original response if the guard itself fails
57
+ return response
58
+
59
+ async def _retrieve_context(
60
+ state: AppState, session_id: str, patient_id: str, message: str
61
+ ) -> str:
62
+ """
63
+ Retrieves enhanced medical context. This is the entry point for RAG.
64
 
65
+ Future RAG Implementation:
66
+ 1. Augment this function to query a vector database or knowledge base.
67
+ 2. Combine the results with the existing memory manager context.
68
+ 3. Return the consolidated context string.
69
+ """
70
  try:
71
+ return await state.memory_manager.get_enhanced_context(
72
  session_id=session_id,
73
  patient_id=patient_id,
74
  question=message,
 
76
  )
77
  except Exception as e:
78
  logger().error(f"Error getting medical context: {e}")
79
+ raise HTTPException(
80
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
81
+ detail="Failed to build medical context."
82
+ ) from e
83
+
84
+ def _add_disclaimer(response_text: str) -> str:
85
+ """Adds a standard medical disclaimer if one is not already present."""
86
+ if "disclaimer" not in response_text.lower() and "consult" not in response_text.lower():
87
+ disclaimer = "\n\n⚠️ **Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice, diagnosis, or treatment. Always consult with qualified healthcare professionals."
88
+ return response_text + disclaimer
89
+ return response_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ async def _persist_exchange(
92
+ state: AppState,
93
+ session_id: str,
94
+ patient_id: str,
95
+ account_id: str,
96
+ question: str,
97
+ answer: str
98
+ ):
99
+ """Processes and stores the full conversation exchange."""
100
  summary = await state.memory_manager.process_medical_exchange(
101
  session_id=session_id,
102
  patient_id=patient_id,
103
  doctor_id=account_id,
104
+ question=question,
105
+ answer=answer,
106
  gemini_rotator=state.gemini_rotator,
107
  nvidia_rotator=state.nvidia_rotator
108
  )
109
  if not summary:
110
  logger().warning(f"Failed to process and store medical exchange for session {session_id}")
111
 
 
112
 
113
+ # --- Core Response Generation Logic ---
114
+
115
+ async def generate_llm_response(
116
  account: Account,
117
  message: str,
118
  rotator: APIKeyRotator,
119
+ medical_context: str = ""
120
+ ) -> str | None:
121
+ """
122
+ Generates an intelligent medical response using the LLM, adding a disclaimer.
123
+ This function is now purely for generation, with safety checks handled elsewhere.
124
+ """
125
  prompt = prompt_builder.medical_response_prompt(
126
  account=account,
127
  user_message=message,
128
  medical_context=medical_context
129
  )
130
 
 
131
  response_text = await gemini_chat(prompt, rotator)
132
 
133
+ if not response_text:
134
+ return None
135
+
136
+ response_with_disclaimer = _add_disclaimer(response_text)
137
+ logger().info(f"Gemini response generated, length: {len(response_with_disclaimer)} chars")
138
+ return response_with_disclaimer
139
+
140
+
141
+ # --- Main Pipeline Orchestrator ---
142
+
143
+ async def generate_chat_response(
144
+ state: AppState,
145
+ message: str,
146
+ session_id: str,
147
+ patient_id: str,
148
+ account_id: str
149
+ ) -> str:
150
+ """
151
+ Orchestrates the pipeline for generating a chat response.
152
+ """
153
+ logger().info(f"Starting response pipeline for session {session_id}")
154
+ safety_guard: SafetyGuard | None = None
155
+
156
+ try:
157
+ safety_guard = SafetyGuard(state.nvidia_rotator)
158
+ except Exception as e:
159
+ logger().warning("Safety guard failed to be created, ignoring")
160
+
161
+ # 1. Validate User Query
162
+ _validate_user_query(message, safety_guard)
163
+
164
+ # 2. Retrieve Context (RAG Entry Point)
165
+ medical_context = await _retrieve_context(state, session_id, patient_id, message)
166
+
167
+ # 3. Fetch Account Details
168
+ account = state.memory_manager.get_account(account_id)
169
+ if not account:
170
+ logger().error(f"Account not found for account_id: {account_id}")
171
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
172
+
173
+ # 4. Generate AI Response
174
+ try:
175
+ response_text = await generate_llm_response(
176
+ message=message,
177
+ account=account,
178
+ rotator=state.gemini_rotator,
179
+ medical_context=medical_context
180
+ )
181
+ # If LLM fails, use a fallback
182
+ if not response_text:
183
+ logger().warning("LLM response failed, using fallback.")
184
+ response_text = _generate_fallback_response(message=message, account=account)
185
+
186
+ except Exception as e:
187
+ logger().error(f"Error generating medical response: {e}")
188
+ raise HTTPException(
189
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
190
+ detail="Failed to generate AI response."
191
+ ) from e
192
+
193
+ # 5. Validate Model's Response
194
+ final_response = _validate_model_response(message, response_text, safety_guard)
195
+
196
+ # 6. Persist the Exchange (Asynchronously)
197
+ # This can be done in the background if it's not critical for the user response
198
+ await _persist_exchange(
199
+ state=state,
200
+ session_id=session_id,
201
+ patient_id=patient_id,
202
+ account_id=account_id,
203
+ question=message,
204
+ answer=final_response
205
+ )
206
+
207
+ return final_response
208
 
209
  def _generate_fallback_response(
210
  message: str,