dylanglenister committed · Commit 84d39f9 · 1 Parent(s): 6d1027d

REFACTOR: Refactoring chat pipeline.

Using functional decomposition to make working on this file easier in the future.

src/core/response_pipeline.py · CHANGED · +153 -90
@@ -11,39 +11,64 @@ from src.services.guard import SafetyGuard
 from src.utils.logger import logger
 from src.utils.rotator import APIKeyRotator
 
-async def generate_chat_response(
-    state: AppState,
-    message: str,
-    session_id: str,
-    patient_id: str,
-    account_id: str
-) -> str:
+# --- Private Helper Functions ---
+
+def _validate_user_query(message: str, safety_guard: SafetyGuard | None):
     """
-
-
+    Checks the user's query against the safety guard.
+    Raises an HTTPException if the query is unsafe.
     """
-
-
-    # 0. Safety Guard: Validate user query
+    if not safety_guard: return
     try:
-
-        is_safe, safety_reason = safety_guard.check_user_query(message)
+        is_safe, reason = safety_guard.check_user_query(message)
         if not is_safe:
-            logger().warning(f"Safety guard blocked user query: {safety_reason}")
+            logger().warning(f"Safety guard blocked user query: {reason}")
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
-                detail=f"Query blocked for safety reasons: {safety_reason}"
+                detail=f"Query blocked for safety reasons: {reason}"
             )
-        logger().info(f"User query passed safety validation: {safety_reason}")
     except Exception as e:
-        logger().error(f"Safety guard error: {e}")
-        raise
+        logger().error(f"Safety guard failed on user query: {e}")
+        # Re-raise to be caught by the main orchestrator
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Failed to validate user query safety."
+        ) from e
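The extracted helper is fail-closed: an unsafe verdict raises a 400 and an error inside the guard itself raises a 500, so either way the request stops. One quirk worth noting: because `HTTPException` subclasses `Exception`, the 400 raised inside the `try` is caught by the `except Exception` handler and re-surfaced as a 500. A minimal standalone sketch of the intended two-path behavior; `StubGuard` is a hypothetical stand-in for the project's `SafetyGuard`:

```python
# Sketch only: StubGuard is hypothetical; the real SafetyGuard wraps a safety model.
from fastapi import HTTPException, status

class StubGuard:
    def check_user_query(self, message: str) -> tuple[bool, str]:
        # Toy rule standing in for a model-based safety check.
        if "ignore previous instructions" in message.lower():
            return False, "possible prompt injection"
        return True, "no issues found"

def validate_user_query(message: str, guard: StubGuard | None) -> None:
    if not guard:
        return  # no guard configured: skip validation
    try:
        is_safe, reason = guard.check_user_query(message)
    except Exception as e:
        # Fail closed: a broken guard also blocks the request.
        raise HTTPException(status_code=500, detail="Failed to validate user query safety.") from e
    if not is_safe:
        # Kept outside the try so the 400 is not converted into a 500.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Query blocked for safety reasons: {reason}",
        )

validate_user_query("What are common migraine triggers?", StubGuard())  # passes silently
```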
+
+def _validate_model_response(query: str, response: str, safety_guard: SafetyGuard | None) -> str:
+    """
+    Checks the generated model response against the safety guard.
+    Returns a safe fallback message if the response is deemed unsafe.
+    """
+    if not safety_guard: return response
+    safe_fallback = "I apologize, but I cannot provide a response to that query as it may contain unsafe content. Please consult with a qualified healthcare professional for medical advice."
+    try:
+        is_safe, reason = safety_guard.check_model_answer(query, response)
+        if not is_safe:
+            logger().warning(f"Safety guard blocked AI response: {reason}")
+            return safe_fallback
+        logger().info(f"AI response passed safety validation: {reason}")
+        return response
+    except Exception as e:
+        logger().error(f"Safety guard failed on model response: {e}")
+        logger().warning("Safety guard failed, allowing response through (fail-open)")
+        # Fail open: return the original response if the guard itself fails
+        return response
+
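The response-side check inverts that policy and fails open: a guard outage degrades to an unchecked answer rather than an error, and the only signal is the warning log. A runnable sketch of the contract, with an illustrative broken guard:

```python
# Demonstrates the fail-open contract of the response-side check (sketch).
FALLBACK = ("I apologize, but I cannot provide a response to that query as it may "
            "contain unsafe content. Please consult with a qualified healthcare "
            "professional for medical advice.")

def validate_model_response(query: str, response: str, guard) -> str:
    if not guard:
        return response
    try:
        is_safe, _reason = guard.check_model_answer(query, response)
        return response if is_safe else FALLBACK
    except Exception:
        return response  # fail open: a crashing guard never blocks the answer

class BrokenGuard:
    def check_model_answer(self, query: str, response: str):
        raise RuntimeError("safety backend unreachable")

print(validate_model_response("q", "original answer", BrokenGuard()))
# -> "original answer": the guard failure is swallowed, not surfaced to the user
```

Fail-closed on input plus fail-open on output favors availability, but it means the pipeline's only output filter disappears silently whenever the guard is down.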
+async def _retrieve_context(
+    state: AppState, session_id: str, patient_id: str, message: str
+) -> str:
+    """
+    Retrieves enhanced medical context. This is the entry point for RAG.
 
-
-
+    Future RAG Implementation:
+    1. Augment this function to query a vector database or knowledge base.
+    2. Combine the results with the existing memory manager context.
+    3. Return the consolidated context string.
+    """
     try:
-        medical_context = await state.memory_manager.get_enhanced_context(
+        return await state.memory_manager.get_enhanced_context(
             session_id=session_id,
             patient_id=patient_id,
             question=message,
@@ -51,97 +76,135 @@ async def generate_chat_response(
         )
     except Exception as e:
         logger().error(f"Error getting medical context: {e}")
-        raise HTTPException(
-
-
-
-
-
-
-
-
-
-            rotator=state.gemini_rotator,
-            medical_context=medical_context,
-            nvidia_rotator=state.nvidia_rotator
-        )
-    except Exception as e:
-        logger().error(f"Error generating medical response: {e}")
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to generate AI response.")
-
-    # TODO Safety guard is applied twice
-    # 2.5. Safety Guard: Validate AI response
-    try:
-        is_safe, safety_reason = safety_guard.check_model_answer(message, response_text) # type: ignore
-        if not is_safe:
-            logger().warning(f"Safety guard blocked AI response: {safety_reason}")
-            response_text = "I apologize, but I cannot provide a response to that query as it may contain unsafe content. Please consult with a qualified healthcare professional for medical advice."
-        else:
-            logger().info(f"AI response passed safety validation: {safety_reason}")
-    except Exception as e:
-        logger().error(f"Safety guard error for response: {e}")
-        logger().warning("Safety guard failed for response, allowing through")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Failed to build medical context."
+        ) from e
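The docstring marks this function as the future RAG seam. A hedged sketch of the three steps it lists, assuming a hypothetical `state.vector_store` with an async `search()` API (nothing in this commit provides one), and abbreviating `get_enhanced_context` to the arguments visible in the diff:

```python
# Hedged sketch of the docstring's 3-step plan; state.vector_store and its
# search() API are hypothetical, not part of this repository.
async def retrieve_context_with_rag(state, session_id: str, patient_id: str, message: str) -> str:
    memory_context = await state.memory_manager.get_enhanced_context(
        session_id=session_id,
        patient_id=patient_id,
        question=message,
    )
    # 1. Query a vector database / knowledge base (hypothetical API).
    hits = await state.vector_store.search(query=message, top_k=3)
    knowledge = "\n".join(hit.text for hit in hits)
    # 2. Combine the results with the existing memory manager context.
    # 3. Return the consolidated context string.
    return f"{memory_context}\n\n[Knowledge base]\n{knowledge}"
```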
+
+def _add_disclaimer(response_text: str) -> str:
+    """Adds a standard medical disclaimer if one is not already present."""
+    if "disclaimer" not in response_text.lower() and "consult" not in response_text.lower():
+        disclaimer = "\n\n⚠️ **Important Disclaimer:** This information is for educational purposes only and should not replace professional medical advice, diagnosis, or treatment. Always consult with qualified healthcare professionals."
+        return response_text + disclaimer
+    return response_text
 
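The disclaimer gate is a keyword heuristic: any response already containing "disclaimer" or "consult" is left untouched, so a reply that merely says "consult your cardiologist" skips the full disclaimer block. The logic inlined for demonstration:

```python
DISCLAIMER = ("\n\n⚠️ **Important Disclaimer:** This information is for educational "
              "purposes only and should not replace professional medical advice, "
              "diagnosis, or treatment. Always consult with qualified healthcare professionals.")

def add_disclaimer(response_text: str) -> str:
    # Same logic as _add_disclaimer in the diff, inlined for demonstration.
    if "disclaimer" not in response_text.lower() and "consult" not in response_text.lower():
        return response_text + DISCLAIMER
    return response_text

print(add_disclaimer("Ibuprofen is a common NSAID."))        # disclaimer appended
print(add_disclaimer("Please consult your cardiologist."))   # returned unchanged
```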
-
+async def _persist_exchange(
+    state: AppState,
+    session_id: str,
+    patient_id: str,
+    account_id: str,
+    question: str,
+    answer: str
+):
+    """Processes and stores the full conversation exchange."""
     summary = await state.memory_manager.process_medical_exchange(
         session_id=session_id,
         patient_id=patient_id,
         doctor_id=account_id,
-        question=message,
-        answer=response_text,
+        question=question,
+        answer=answer,
         gemini_rotator=state.gemini_rotator,
         nvidia_rotator=state.nvidia_rotator
     )
     if not summary:
         logger().warning(f"Failed to process and store medical exchange for session {session_id}")
 
-    return response_text
 
-
+# --- Core Response Generation Logic ---
+
+async def generate_llm_response(
     account: Account,
     message: str,
     rotator: APIKeyRotator,
-    medical_context: str = ""
-
-
-
+    medical_context: str = ""
+) -> str | None:
+    """
+    Generates an intelligent medical response using the LLM, adding a disclaimer.
+    This function is now purely for generation, with safety checks handled elsewhere.
+    """
     prompt = prompt_builder.medical_response_prompt(
         account=account,
         user_message=message,
         medical_context=medical_context
     )
 
-    # Generate response using Gemini
     response_text = await gemini_chat(prompt, rotator)
 
-    if response_text:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not response_text:
+        return None
+
+    response_with_disclaimer = _add_disclaimer(response_text)
+    logger().info(f"Gemini response generated, length: {len(response_with_disclaimer)} chars")
+    return response_with_disclaimer
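`generate_llm_response` now returns `str | None` rather than raising when the model yields nothing; `None` is the orchestrator's cue (step 4 below) to substitute `_generate_fallback_response`. The contract in miniature, with stubs in place of the real generator and fallback:

```python
import asyncio

async def generate_stub(message: str) -> str | None:
    # Stand-in for generate_llm_response; pretend the LLM returned nothing.
    return None

def fallback(message: str) -> str:
    # Stand-in for _generate_fallback_response.
    return "I could not generate a full answer right now; please consult a clinician."

async def main() -> None:
    response = await generate_stub("What is hypertension?")
    if not response:  # mirrors step 4 of the orchestrator
        response = fallback("What is hypertension?")
    print(response)

asyncio.run(main())
```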
+
+
+# --- Main Pipeline Orchestrator ---
+
+async def generate_chat_response(
+    state: AppState,
+    message: str,
+    session_id: str,
+    patient_id: str,
+    account_id: str
+) -> str:
+    """
+    Orchestrates the pipeline for generating a chat response.
+    """
+    logger().info(f"Starting response pipeline for session {session_id}")
+    safety_guard: SafetyGuard | None = None
+
+    try:
+        safety_guard = SafetyGuard(state.nvidia_rotator)
+    except Exception as e:
+        logger().warning("Safety guard failed to be created, ignoring")
+
+    # 1. Validate User Query
+    _validate_user_query(message, safety_guard)
+
+    # 2. Retrieve Context (RAG Entry Point)
+    medical_context = await _retrieve_context(state, session_id, patient_id, message)
+
+    # 3. Fetch Account Details
+    account = state.memory_manager.get_account(account_id)
+    if not account:
+        logger().error(f"Account not found for account_id: {account_id}")
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
+
+    # 4. Generate AI Response
+    try:
+        response_text = await generate_llm_response(
+            message=message,
+            account=account,
+            rotator=state.gemini_rotator,
+            medical_context=medical_context
+        )
+        # If LLM fails, use a fallback
+        if not response_text:
+            logger().warning("LLM response failed, using fallback.")
+            response_text = _generate_fallback_response(message=message, account=account)
+
+    except Exception as e:
+        logger().error(f"Error generating medical response: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Failed to generate AI response."
+        ) from e
+
+    # 5. Validate Model's Response
+    final_response = _validate_model_response(message, response_text, safety_guard)
+
+    # 6. Persist the Exchange (Asynchronously)
+    # This can be done in the background if it's not critical for the user response
+    await _persist_exchange(
+        state=state,
+        session_id=session_id,
+        patient_id=patient_id,
+        account_id=account_id,
+        question=message,
+        answer=final_response
+    )
+
+    return final_response
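Step 6 is awaited inline even though the comment notes it could run in the background. A sketch of that deferral with `asyncio.create_task`; the task-set bookkeeping keeps fire-and-forget tasks referenced until they finish, and the call shape of `_persist_exchange` is taken from the diff:

```python
import asyncio

_background_tasks: set[asyncio.Task] = set()

def persist_in_background(coro) -> None:
    """Schedule persistence without blocking the response path (sketch)."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)                       # strong reference: avoid mid-flight GC
    task.add_done_callback(_background_tasks.discard)

# Step 6 would then become (inside the orchestrator):
# persist_in_background(_persist_exchange(
#     state=state, session_id=session_id, patient_id=patient_id,
#     account_id=account_id, question=message, answer=final_response,
# ))
```

The trade-off: persistence failures no longer delay or fail the request, but they also surface only in logs.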
 
 def _generate_fallback_response(
     message: str,
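For completeness, one way a route could drive the refactored entry point. The path, request schema, and `app.state` wiring below are hypothetical, not taken from this commit; only `generate_chat_response` and its parameters come from the diff:

```python
from fastapi import APIRouter, Request
from pydantic import BaseModel

from src.core.response_pipeline import generate_chat_response

router = APIRouter()

class ChatRequest(BaseModel):
    # Hypothetical request schema; the real API models are not shown in this diff.
    message: str
    session_id: str
    patient_id: str
    account_id: str

@router.post("/chat")
async def chat(body: ChatRequest, request: Request) -> dict[str, str]:
    reply = await generate_chat_response(
        state=request.app.state,  # assumes AppState is attached to app.state
        message=body.message,
        session_id=body.session_id,
        patient_id=body.patient_id,
        account_id=body.account_id,
    )
    return {"reply": reply}
```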