Spaces:
Sleeping
Sleeping
Initial CPS-API deployment with TxAgent integration
Browse files- analysis.py +280 -0
- api/routes/txagent.py +97 -74
- api/services/txagent_service.py +56 -107
- core/txagent_config.py +7 -12
- data/new_tool.json +1 -0
- db/mongo.py +27 -0
- requirements.txt +23 -5
- src/__init__.py +6 -0
- src/toolrag.py +67 -0
- src/txagent.py +154 -0
- src/utils.py +94 -0
- utils.py +110 -0
- voice.py +50 -0
analysis.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple, List
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from config import agent, patients_collection, analysis_collection, alerts_collection, logger
|
| 4 |
+
from models import RiskLevel
|
| 5 |
+
from utils import (
|
| 6 |
+
structure_medical_response,
|
| 7 |
+
compute_file_content_hash,
|
| 8 |
+
compute_patient_data_hash,
|
| 9 |
+
serialize_patient,
|
| 10 |
+
broadcast_notification
|
| 11 |
+
)
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
import asyncio
|
| 14 |
+
import json
|
| 15 |
+
import re
|
| 16 |
+
import os
|
| 17 |
+
class NotificationType(str, Enum):
    """Categories of notifications pushed to clients.

    str-valued so members serialize transparently to JSON and MongoDB.
    """
    RISK_ALERT = "risk_alert"  # patient risk detected by analysis
    SYSTEM = "system"          # internal/system events (e.g. analysis failures)
    MESSAGE = "message"        # generic user-facing message
|
| 21 |
+
|
| 22 |
+
class NotificationStatus(str, Enum):
    """Lifecycle state of a notification (str-valued for JSON/Mongo serialization)."""
    UNREAD = "unread"      # default state on creation
    READ = "read"          # acknowledged by a user
    ARCHIVED = "archived"  # hidden from the active notification list
|
| 26 |
+
|
| 27 |
+
async def create_alert(patient_id: str, risk_data: dict) -> dict:
    """Persist a suicide-risk alert for a patient and broadcast its notification.

    Args:
        patient_id: FHIR identifier of the patient the alert concerns.
        risk_data: dict with keys "level" (str), "score" (float) and
            "factors" (list[str]), as produced by detect_suicide_risk().

    Returns:
        The alert document that was inserted.

    Raises:
        Exception: re-raises any storage/broadcast failure after logging it.
    """
    try:
        # Hoist the repeated lookup; `level` is used four times below.
        level = risk_data["level"]
        alert_doc = {
            "patient_id": patient_id,
            "type": "suicide_risk",
            "level": level,
            "score": risk_data["score"],
            "factors": risk_data["factors"],
            # NOTE: naive UTC timestamp — consistent with the rest of this module.
            "timestamp": datetime.utcnow(),
            "acknowledged": False,
            "notification": {
                # Use the enums declared above instead of raw literals, matching
                # the error paths elsewhere in this module. `.value` keeps the
                # stored strings byte-identical to before.
                "type": NotificationType.RISK_ALERT.value,
                "status": NotificationStatus.UNREAD.value,
                "title": f"Suicide Risk: {level.capitalize()}",
                "message": f"Patient {patient_id} shows {level} risk factors",
                "icon": "⚠️",
                "action_url": f"/patient/{patient_id}/risk-assessment",
                "priority": "high" if level in ["high", "severe"] else "medium"
            }
        }

        await alerts_collection.insert_one(alert_doc)

        # Simplified WebSocket notification - remove Hugging Face specific code
        await broadcast_notification(alert_doc["notification"])

        logger.warning(f"⚠️ Created suicide risk alert for patient {patient_id}")
        return alert_doc
    except Exception as e:
        logger.error(f"Failed to create alert: {str(e)}")
        raise
|
| 58 |
+
async def analyze_patient_report(
    patient_id: Optional[str],
    report_content: str,
    file_type: str,
    file_content: bytes
) -> dict:
    """Analyze a patient report and create alerts for risks.

    The report is keyed by ``patient_id`` when known, otherwise by a hash of
    the raw file bytes. Re-analysis is skipped when an identical report
    (same identifier + content hash) has already been analyzed.

    Args:
        patient_id: FHIR patient id, or None for an unattributed upload.
        report_content: extracted text of the report.
        file_type: source format label used in the prompt and stored doc.
        file_content: raw file bytes, hashed as the fallback identifier.

    Returns:
        The stored (or previously cached) analysis document.

    Raises:
        Exception: re-raised after recording a system-error alert.
    """
    # Fall back to a content hash when there is no patient to attribute to.
    identifier = patient_id if patient_id else compute_file_content_hash(file_content)
    report_data = {"identifier": identifier, "content": report_content, "file_type": file_type}
    report_hash = compute_patient_data_hash(report_data)
    logger.info(f"🧾 Analyzing report for identifier: {identifier}")

    # Check for existing analysis: identifier + hash match means the exact
    # same report content was already processed.
    existing_analysis = await analysis_collection.find_one(
        {"identifier": identifier, "report_hash": report_hash}
    )
    if existing_analysis:
        logger.info(f"✅ No changes in report data for {identifier}, skipping analysis")
        return existing_analysis

    try:
        # Generate analysis; report text truncated to 10k chars to bound prompt size.
        prompt = (
            "You are a clinical decision support AI. Analyze the following patient report:\n"
            "1. Summarize the patient's medical history.\n"
            "2. Identify risks or red flags (including mental health and suicide risk).\n"
            "3. Highlight missed diagnoses or treatments.\n"
            "4. Suggest next clinical steps.\n"
            f"\nPatient Report ({file_type}):\n{'-'*40}\n{report_content[:10000]}"
        )

        # NOTE(review): agent.chat is called synchronously inside an async
        # function — presumably blocking; confirm it doesn't stall the event loop.
        raw_response = agent.chat(
            message=prompt,
            history=[],
            temperature=0.7,
            max_new_tokens=1024
        )
        structured_response = structure_medical_response(raw_response)

        # Detect suicide risk from the raw (unstructured) model output.
        risk_level, risk_score, risk_factors = detect_suicide_risk(raw_response)
        suicide_risk = {
            "level": risk_level.value,
            "score": risk_score,
            "factors": risk_factors
        }

        # Store analysis (upsert keyed on identifier + report hash).
        analysis_doc = {
            "identifier": identifier,
            "patient_id": patient_id,
            "timestamp": datetime.utcnow(),
            "summary": structured_response,
            "suicide_risk": suicide_risk,
            "raw": raw_response,
            "report_hash": report_hash,
            "file_type": file_type
        }

        await analysis_collection.update_one(
            {"identifier": identifier, "report_hash": report_hash},
            {"$set": analysis_doc},
            upsert=True
        )

        # Create alert if risk detected — only when the report is tied to a patient.
        if patient_id and risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
            await create_alert(patient_id, suicide_risk)

        logger.info(f"✅ Stored analysis for identifier {identifier}")
        return analysis_doc

    except Exception as e:
        # Record the failure as a high-priority system alert before re-raising.
        logger.error(f"Error analyzing report for {identifier}: {str(e)}")
        error_alert = {
            "identifier": identifier,
            "type": "system_error",
            "level": "high",
            "message": f"Report analysis failed: {str(e)}",
            "timestamp": datetime.utcnow(),
            "acknowledged": False,
            "notification": {
                # NOTE(review): enum members stored here (str subclasses) —
                # presumably encoded as plain strings by the driver; verify.
                "type": NotificationType.SYSTEM,
                "status": NotificationStatus.UNREAD,
                "title": "Report Analysis Error",
                "message": f"Failed to analyze report for {'patient ' + patient_id if patient_id else 'unknown identifier'}",
                "icon": "❌",
                "action_url": "/system/errors",
                "priority": "high"
            }
        }
        await alerts_collection.insert_one(error_alert)
        raise
|
| 151 |
+
|
| 152 |
+
async def analyze_patient(patient: dict) -> None:
    """Analyze complete patient record and create alerts for risks.

    Skips work when the patient's serialized data hash matches the stored
    analysis; otherwise runs the clinical-support prompt, stores the result
    (upsert keyed on identifier) and raises a risk alert when warranted.

    Args:
        patient: raw patient document (as read from the patients collection).

    Raises:
        Exception: re-raised after recording a system-error alert.
    """
    try:
        serialized = serialize_patient(patient)
        patient_id = serialized.get("fhir_id")
        patient_hash = compute_patient_data_hash(serialized)
        logger.info(f"🧾 Analyzing patient: {patient_id}")

        # Check for existing analysis — an unchanged data hash means nothing to do.
        existing_analysis = await analysis_collection.find_one({"patient_id": patient_id})
        if existing_analysis and existing_analysis.get("data_hash") == patient_hash:
            logger.info(f"✅ No changes in patient data for {patient_id}, skipping analysis")
            return

        # Generate analysis; document truncated to 10k chars to bound prompt size.
        doc = json.dumps(serialized, indent=2)
        message = (
            "You are a clinical decision support AI.\n\n"
            "Given the patient document below:\n"
            "1. Summarize the patient's medical history.\n"
            "2. Identify risks or red flags (including mental health and suicide risk).\n"
            "3. Highlight missed diagnoses or treatments.\n"
            "4. Suggest next clinical steps.\n"
            f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
        )

        # NOTE(review): synchronous agent.chat inside async code — confirm it
        # does not block the event loop for long generations.
        raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        structured = structure_medical_response(raw)

        # Detect suicide risk from the raw model output.
        risk_level, risk_score, risk_factors = detect_suicide_risk(raw)
        suicide_risk = {
            "level": risk_level.value,
            "score": risk_score,
            "factors": risk_factors
        }

        # Store analysis (upsert keyed on identifier == patient_id).
        analysis_doc = {
            "identifier": patient_id,
            "patient_id": patient_id,
            "timestamp": datetime.utcnow(),
            "summary": structured,
            "suicide_risk": suicide_risk,
            "raw": raw,
            "data_hash": patient_hash
        }

        await analysis_collection.update_one(
            {"identifier": patient_id},
            {"$set": analysis_doc},
            upsert=True
        )

        # Create alert if risk detected.
        if risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
            await create_alert(patient_id, suicide_risk)

        logger.info(f"✅ Stored analysis for patient {patient_id}")

    except Exception as e:
        # patient_id may be unbound if serialization itself failed — hence the
        # locals() check before using it in the alert document.
        logger.error(f"Error analyzing patient: {str(e)}")
        error_alert = {
            "patient_id": patient_id if 'patient_id' in locals() else "unknown",
            "type": "system_error",
            "level": "high",
            "message": f"Patient analysis failed: {str(e)}",
            "timestamp": datetime.utcnow(),
            "acknowledged": False,
            "notification": {
                "type": NotificationType.SYSTEM,
                "status": NotificationStatus.UNREAD,
                "title": "Analysis Error",
                "message": f"Failed to analyze patient {patient_id if 'patient_id' in locals() else 'unknown'}",
                "icon": "❌",
                "action_url": "/system/errors",
                "priority": "high"
            }
        }
        await alerts_collection.insert_one(error_alert)
        raise
|
| 233 |
+
|
| 234 |
+
def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
    """Detect suicide risk level from text analysis.

    Strategy: a cheap keyword screen runs first; only when explicit mentions
    are found is the (expensive) AI assessment invoked. If the AI call or its
    JSON parsing fails, a keyword-count heuristic is used as a fallback.

    Args:
        text: free text to assess (e.g. a raw model response or report text).

    Returns:
        (risk_level, risk_score, risk_factors) with risk_score in [0, 1].
    """
    suicide_keywords = [
        'suicide', 'suicidal', 'kill myself', 'end my life',
        'want to die', 'self-harm', 'self harm', 'hopeless',
        'no reason to live', 'plan to die'
    ]
    # Lowercase once instead of once per keyword (text can be large).
    lowered = text.lower()
    explicit_mentions = [kw for kw in suicide_keywords if kw in lowered]
    if not explicit_mentions:
        return RiskLevel.NONE, 0.0, []

    try:
        # Get AI assessment
        assessment_prompt = (
            "Assess the suicide risk level based on this text. "
            "Consider frequency, specificity, and severity of statements. "
            "Respond with JSON format: {\"risk_level\": \"low/moderate/high/severe\", "
            "\"risk_score\": 0-1, \"factors\": [\"list of risk factors\"]}\n\n"
            f"Text to assess:\n{text}"
        )

        response = agent.chat(
            message=assessment_prompt,
            history=[],
            temperature=0.2,  # low temperature for a more deterministic assessment
            max_new_tokens=256
        )

        # Parse the first JSON object found in the reply. An unrecognized
        # risk_level string raises ValueError and drops to the fallback below.
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if json_match:
            assessment = json.loads(json_match.group())
            return (
                RiskLevel(assessment.get("risk_level", "none").lower()),
                float(assessment.get("risk_score", 0)),
                assessment.get("factors", [])
            )
    except Exception as e:
        logger.error(f"Error in suicide risk assessment: {e}")

    # Fallback heuristic if AI assessment fails: 0.1 per matched keyword, capped at 0.9.
    risk_score = min(0.1 * len(explicit_mentions), 0.9)
    if risk_score > 0.7:
        return RiskLevel.HIGH, risk_score, explicit_mentions
    elif risk_score > 0.4:
        return RiskLevel.MODERATE, risk_score, explicit_mentions
    return RiskLevel.LOW, risk_score, explicit_mentions
|
api/routes/txagent.py
CHANGED
|
@@ -1,57 +1,99 @@
|
|
| 1 |
-
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Query
|
| 2 |
-
from fastapi.responses import StreamingResponse
|
|
|
|
| 3 |
from typing import Optional, List
|
| 4 |
from pydantic import BaseModel
|
| 5 |
from core.security import get_current_user
|
| 6 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import logging
|
| 8 |
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
class ChatRequest(BaseModel):
|
| 14 |
message: str
|
| 15 |
history: Optional[List[dict]] = None
|
|
|
|
|
|
|
|
|
|
| 16 |
patient_id: Optional[str] = None
|
| 17 |
|
| 18 |
class VoiceOutputRequest(BaseModel):
|
| 19 |
text: str
|
| 20 |
language: str = "en-US"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
@router.get("/txagent/status")
|
| 23 |
-
async def
|
| 24 |
-
"
|
| 25 |
-
|
| 26 |
-
status
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
}
|
| 32 |
-
except Exception as e:
|
| 33 |
-
logger.error(f"Error getting TxAgent status: {e}")
|
| 34 |
-
raise HTTPException(status_code=500, detail="Failed to get TxAgent status")
|
| 35 |
|
| 36 |
@router.get("/txagent/patients/analysis-results")
|
| 37 |
async def get_patient_analysis_results(
|
| 38 |
name: Optional[str] = Query(None),
|
| 39 |
current_user: dict = Depends(get_current_user)
|
| 40 |
):
|
| 41 |
-
"
|
| 42 |
try:
|
| 43 |
# Check if user has appropriate permissions
|
| 44 |
if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
|
| 45 |
raise HTTPException(status_code=403, detail="Only doctors and admins can access analysis results")
|
| 46 |
|
| 47 |
-
#
|
| 48 |
-
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
except Exception as e:
|
| 53 |
-
logger.error(f"Error
|
| 54 |
-
# Return empty array instead of throwing error to prevent 500
|
| 55 |
return []
|
| 56 |
|
| 57 |
@router.post("/txagent/chat")
|
|
@@ -59,22 +101,19 @@ async def chat_with_txagent(
|
|
| 59 |
request: ChatRequest,
|
| 60 |
current_user: dict = Depends(get_current_user)
|
| 61 |
):
|
| 62 |
-
"""Chat avec TxAgent"""
|
| 63 |
try:
|
| 64 |
# Vérifier que l'utilisateur est médecin ou admin
|
| 65 |
if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
|
| 66 |
raise HTTPException(status_code=403, detail="Only doctors and admins can use TxAgent")
|
| 67 |
|
| 68 |
-
response
|
| 69 |
-
|
| 70 |
-
history=request.history,
|
| 71 |
-
patient_id=request.patient_id
|
| 72 |
-
)
|
| 73 |
|
| 74 |
return {
|
| 75 |
"status": "success",
|
| 76 |
"response": response,
|
| 77 |
-
"mode":
|
| 78 |
}
|
| 79 |
except Exception as e:
|
| 80 |
logger.error(f"Error in TxAgent chat: {e}")
|
|
@@ -85,18 +124,16 @@ async def transcribe_audio(
|
|
| 85 |
audio: UploadFile = File(...),
|
| 86 |
current_user: dict = Depends(get_current_user)
|
| 87 |
):
|
| 88 |
-
"""Transcription vocale avec TxAgent"""
|
| 89 |
try:
|
| 90 |
if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
|
| 91 |
raise HTTPException(status_code=403, detail="Only doctors and admins can use voice features")
|
| 92 |
|
| 93 |
-
|
| 94 |
-
result = await txagent_service.voice_transcribe(audio_data)
|
| 95 |
-
|
| 96 |
return {
|
| 97 |
"status": "success",
|
| 98 |
-
"transcription":
|
| 99 |
-
"mode":
|
| 100 |
}
|
| 101 |
except Exception as e:
|
| 102 |
logger.error(f"Error in voice transcription: {e}")
|
|
@@ -107,63 +144,49 @@ async def synthesize_speech(
|
|
| 107 |
request: VoiceOutputRequest,
|
| 108 |
current_user: dict = Depends(get_current_user)
|
| 109 |
):
|
| 110 |
-
"""Synthèse vocale avec TxAgent"""
|
| 111 |
try:
|
| 112 |
if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
|
| 113 |
raise HTTPException(status_code=403, detail="Only doctors and admins can use voice features")
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
language=request.language
|
| 118 |
-
)
|
| 119 |
|
| 120 |
return StreamingResponse(
|
| 121 |
iter([audio_data]),
|
| 122 |
media_type="audio/mpeg",
|
| 123 |
-
headers={
|
| 124 |
-
"Content-Disposition": "attachment; filename=synthesized_speech.mp3"
|
| 125 |
-
}
|
| 126 |
)
|
| 127 |
except Exception as e:
|
| 128 |
logger.error(f"Error in voice synthesis: {e}")
|
| 129 |
raise HTTPException(status_code=500, detail="Failed to synthesize speech")
|
| 130 |
|
| 131 |
-
@router.post("/txagent/patients/analyze")
|
| 132 |
-
async def analyze_patient_data(
|
| 133 |
-
patient_data: dict,
|
| 134 |
-
current_user: dict = Depends(get_current_user)
|
| 135 |
-
):
|
| 136 |
-
"""Analyse de données patient avec TxAgent"""
|
| 137 |
-
try:
|
| 138 |
-
if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
|
| 139 |
-
raise HTTPException(status_code=403, detail="Only doctors and admins can use analysis features")
|
| 140 |
-
|
| 141 |
-
analysis = await txagent_service.analyze_patient(patient_data)
|
| 142 |
-
|
| 143 |
-
return {
|
| 144 |
-
"status": "success",
|
| 145 |
-
"analysis": analysis,
|
| 146 |
-
"mode": txagent_service.config.get_txagent_mode()
|
| 147 |
-
}
|
| 148 |
-
except Exception as e:
|
| 149 |
-
logger.error(f"Error in patient analysis: {e}")
|
| 150 |
-
raise HTTPException(status_code=500, detail="Failed to analyze patient data")
|
| 151 |
-
|
| 152 |
@router.get("/txagent/chats")
|
| 153 |
async def get_chats(current_user: dict = Depends(get_current_user)):
|
| 154 |
"""Obtient l'historique des chats"""
|
| 155 |
try:
|
| 156 |
if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
|
| 157 |
-
raise HTTPException(status_code=403, detail="Only doctors and admins can access
|
| 158 |
|
| 159 |
-
#
|
| 160 |
-
|
|
|
|
| 161 |
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
except Exception as e:
|
| 168 |
logger.error(f"Error getting chats: {e}")
|
| 169 |
raise HTTPException(status_code=500, detail="Failed to get chats")
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Query, Path
|
| 2 |
+
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
|
| 3 |
+
from fastapi.encoders import jsonable_encoder
|
| 4 |
from typing import Optional, List
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from core.security import get_current_user
|
| 7 |
+
from utils import clean_text_response
|
| 8 |
+
from analysis import analyze_patient_report
|
| 9 |
+
from voice import recognize_speech, text_to_speech, extract_text_from_pdf
|
| 10 |
+
from docx import Document
|
| 11 |
+
import re
|
| 12 |
+
import io
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from bson import ObjectId
|
| 15 |
+
import asyncio
|
| 16 |
+
from bson.errors import InvalidId
|
| 17 |
+
import base64
|
| 18 |
+
import os
|
| 19 |
+
from pathlib import Path as PathLib
|
| 20 |
+
import tempfile
|
| 21 |
+
import subprocess
|
| 22 |
import logging
|
| 23 |
|
| 24 |
logger = logging.getLogger(__name__)
|
| 25 |
|
| 26 |
+
# Define the ChatRequest model with an optional patient_id
|
|
|
|
| 27 |
class ChatRequest(BaseModel):
|
| 28 |
message: str
|
| 29 |
history: Optional[List[dict]] = None
|
| 30 |
+
format: Optional[str] = "clean"
|
| 31 |
+
temperature: Optional[float] = 0.7
|
| 32 |
+
max_new_tokens: Optional[int] = 512
|
| 33 |
patient_id: Optional[str] = None
|
| 34 |
|
| 35 |
class VoiceOutputRequest(BaseModel):
|
| 36 |
text: str
|
| 37 |
language: str = "en-US"
|
| 38 |
+
slow: bool = False
|
| 39 |
+
return_format: str = "mp3"
|
| 40 |
+
|
| 41 |
+
class RiskLevel(BaseModel):
|
| 42 |
+
level: str
|
| 43 |
+
score: float
|
| 44 |
+
factors: Optional[List[str]] = None
|
| 45 |
+
|
| 46 |
+
router = APIRouter()
|
| 47 |
|
| 48 |
@router.get("/txagent/status")
|
| 49 |
+
async def status(current_user: dict = Depends(get_current_user)):
|
| 50 |
+
logger.info(f"Status endpoint accessed by {current_user['email']}")
|
| 51 |
+
return {
|
| 52 |
+
"status": "running",
|
| 53 |
+
"timestamp": datetime.utcnow().isoformat(),
|
| 54 |
+
"version": "2.6.0",
|
| 55 |
+
"features": ["chat", "voice-input", "voice-output", "patient-analysis", "report-upload", "patient-reports-pdf", "all-patients-reports-pdf"]
|
| 56 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
@router.get("/txagent/patients/analysis-results")
async def get_patient_analysis_results(
    name: Optional[str] = Query(None),
    current_user: dict = Depends(get_current_user)
):
    """Return stored patient analysis results, newest first (max 100).

    Args:
        name: optional case-insensitive substring filter on patient full name.
        current_user: injected authenticated user; must hold the 'doctor' or
            'admin' role, otherwise a 403 is raised.

    Returns:
        A list of analysis documents enriched with the patient's full name.
        On storage errors an empty list is returned instead of a 500 so the
        UI can render an empty state.
    """
    logger.info(f"Fetching analysis results by {current_user['email']}")

    # Permission check OUTSIDE the try: previously the HTTPException(403) was
    # swallowed by the generic handler and silently turned into `return []`.
    if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
        raise HTTPException(status_code=403, detail="Only doctors and admins can access analysis results")

    try:
        # Import database collections
        from db.mongo import db
        patients_collection = db.patients
        analysis_collection = db.patient_analysis_results

        query = {}
        if name:
            # re.escape: treat the user-supplied name as a literal substring,
            # not a regex pattern (prevents regex injection / ReDoS).
            name_regex = re.compile(re.escape(name), re.IGNORECASE)
            matching_patients = await patients_collection.find({"full_name": name_regex}).to_list(length=None)
            patient_ids = [p["fhir_id"] for p in matching_patients if "fhir_id" in p]
            if not patient_ids:
                return []
            query = {"patient_id": {"$in": patient_ids}}

        analyses = await analysis_collection.find(query).sort("timestamp", -1).to_list(length=100)
        enriched_results = []
        for analysis in analyses:
            patient = await patients_collection.find_one({"fhir_id": analysis.get("patient_id")})
            if not patient:
                continue  # Skip if patient no longer exists
            analysis["full_name"] = patient.get("full_name", "Unknown")
            analysis["_id"] = str(analysis["_id"])  # ObjectId is not JSON-serializable
            enriched_results.append(analysis)

        return enriched_results

    except Exception as e:
        logger.error(f"Error fetching analysis results: {e}")
        # Return empty array instead of throwing error to prevent 500
        return []
|
| 98 |
|
| 99 |
@router.post("/txagent/chat")
|
|
|
|
| 101 |
request: ChatRequest,
|
| 102 |
current_user: dict = Depends(get_current_user)
|
| 103 |
):
|
| 104 |
+
"""Chat avec TxAgent intégré"""
|
| 105 |
try:
|
| 106 |
# Vérifier que l'utilisateur est médecin ou admin
|
| 107 |
if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
|
| 108 |
raise HTTPException(status_code=403, detail="Only doctors and admins can use TxAgent")
|
| 109 |
|
| 110 |
+
# For now, return a simple response since the full TxAgent is not yet implemented
|
| 111 |
+
response = f"TxAgent integrated response: {request.message}"
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
return {
|
| 114 |
"status": "success",
|
| 115 |
"response": response,
|
| 116 |
+
"mode": "integrated"
|
| 117 |
}
|
| 118 |
except Exception as e:
|
| 119 |
logger.error(f"Error in TxAgent chat: {e}")
|
|
|
|
| 124 |
audio: UploadFile = File(...),
|
| 125 |
current_user: dict = Depends(get_current_user)
|
| 126 |
):
|
| 127 |
+
"""Transcription vocale avec TxAgent intégré"""
|
| 128 |
try:
|
| 129 |
if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
|
| 130 |
raise HTTPException(status_code=403, detail="Only doctors and admins can use voice features")
|
| 131 |
|
| 132 |
+
# For now, return mock transcription
|
|
|
|
|
|
|
| 133 |
return {
|
| 134 |
"status": "success",
|
| 135 |
+
"transcription": "Mock voice transcription from integrated TxAgent",
|
| 136 |
+
"mode": "integrated"
|
| 137 |
}
|
| 138 |
except Exception as e:
|
| 139 |
logger.error(f"Error in voice transcription: {e}")
|
|
|
|
| 144 |
request: VoiceOutputRequest,
|
| 145 |
current_user: dict = Depends(get_current_user)
|
| 146 |
):
|
| 147 |
+
"""Synthèse vocale avec TxAgent intégré"""
|
| 148 |
try:
|
| 149 |
if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
|
| 150 |
raise HTTPException(status_code=403, detail="Only doctors and admins can use voice features")
|
| 151 |
|
| 152 |
+
# For now, return mock audio data
|
| 153 |
+
audio_data = b"Mock audio data from integrated TxAgent"
|
|
|
|
|
|
|
| 154 |
|
| 155 |
return StreamingResponse(
|
| 156 |
iter([audio_data]),
|
| 157 |
media_type="audio/mpeg",
|
| 158 |
+
headers={"Content-Disposition": "attachment; filename=speech.mp3"}
|
|
|
|
|
|
|
| 159 |
)
|
| 160 |
except Exception as e:
|
| 161 |
logger.error(f"Error in voice synthesis: {e}")
|
| 162 |
raise HTTPException(status_code=500, detail="Failed to synthesize speech")
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
@router.get("/txagent/chats")
async def get_chats(current_user: dict = Depends(get_current_user)):
    """Return the 50 most recent TxAgent chats, newest first.

    Only doctors and admins may call this; other roles receive a 403.

    Raises:
        HTTPException: 403 for insufficient roles, 500 on storage failure.
    """
    # Permission check OUTSIDE the try: previously the HTTPException(403) was
    # caught by `except Exception` below and converted into a misleading 500.
    if not any(role in current_user.get('roles', []) for role in ['doctor', 'admin']):
        raise HTTPException(status_code=403, detail="Only doctors and admins can access chat history")

    try:
        # Import database collections
        from db.mongo import db
        chats_collection = db.chats

        # Query local database for chat history
        cursor = chats_collection.find().sort("timestamp", -1).limit(50)
        chats = await cursor.to_list(length=50)

        return [
            {
                "id": str(chat["_id"]),
                "message": chat.get("message", ""),
                "response": chat.get("response", ""),
                "timestamp": chat.get("timestamp"),
                "user_id": str(chat.get("user_id", "")),
                "patient_id": str(chat.get("patient_id", "")) if chat.get("patient_id") else None
            }
            for chat in chats
        ]
    except Exception as e:
        logger.error(f"Error getting chats: {e}")
        raise HTTPException(status_code=500, detail="Failed to get chats")
|
api/services/txagent_service.py
CHANGED
|
@@ -1,139 +1,88 @@
|
|
| 1 |
-
import aiohttp
|
| 2 |
-
import asyncio
|
| 3 |
import logging
|
| 4 |
from typing import Optional, Dict, Any, List
|
| 5 |
from core.txagent_config import txagent_config
|
|
|
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
| 9 |
class TxAgentService:
|
| 10 |
def __init__(self):
|
| 11 |
self.config = txagent_config
|
| 12 |
-
self.session = None
|
| 13 |
-
|
| 14 |
-
async def _get_session(self):
|
| 15 |
-
"""Obtient ou crée une session HTTP"""
|
| 16 |
-
if self.session is None:
|
| 17 |
-
self.session = aiohttp.ClientSession()
|
| 18 |
-
return self.session
|
| 19 |
-
|
| 20 |
-
async def _make_request(self, endpoint: str, method: str = "GET", data: Optional[Dict] = None) -> Dict[str, Any]:
|
| 21 |
-
"""Fait une requête vers le service TxAgent avec fallback"""
|
| 22 |
-
session = await self._get_session()
|
| 23 |
-
url = f"{self.config.get_txagent_url()}{endpoint}"
|
| 24 |
-
|
| 25 |
-
try:
|
| 26 |
-
if method.upper() == "GET":
|
| 27 |
-
async with session.get(url) as response:
|
| 28 |
-
return await response.json()
|
| 29 |
-
elif method.upper() == "POST":
|
| 30 |
-
async with session.post(url, json=data) as response:
|
| 31 |
-
return await response.json()
|
| 32 |
-
except Exception as e:
|
| 33 |
-
logger.error(f"Error calling TxAgent service: {e}")
|
| 34 |
-
# Fallback vers cloud si local échoue
|
| 35 |
-
if self.config.get_txagent_mode() == "local":
|
| 36 |
-
logger.info("Falling back to cloud TxAgent service")
|
| 37 |
-
self.config.mode = "cloud"
|
| 38 |
-
return await self._make_request(endpoint, method, data)
|
| 39 |
-
else:
|
| 40 |
-
raise
|
| 41 |
|
| 42 |
async def chat(self, message: str, history: Optional[list] = None, patient_id: Optional[str] = None) -> Dict[str, Any]:
|
| 43 |
-
"""Service de chat avec TxAgent"""
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
"
|
| 47 |
-
"
|
| 48 |
}
|
| 49 |
-
return await self._make_request("/chat", "POST", data)
|
| 50 |
|
| 51 |
async def analyze_patient(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 52 |
-
"""Analyse de données patient avec TxAgent"""
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
async def voice_transcribe(self, audio_data: bytes) -> Dict[str, Any]:
|
| 56 |
-
"""Transcription vocale avec TxAgent"""
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
form_data.add_field('audio', audio_data, filename='audio.wav')
|
| 63 |
-
|
| 64 |
-
async with session.post(url, data=form_data) as response:
|
| 65 |
-
return await response.json()
|
| 66 |
-
except Exception as e:
|
| 67 |
-
logger.error(f"Error in voice transcription: {e}")
|
| 68 |
-
if self.config.get_txagent_mode() == "local":
|
| 69 |
-
self.config.mode = "cloud"
|
| 70 |
-
return await self.voice_transcribe(audio_data)
|
| 71 |
-
else:
|
| 72 |
-
raise
|
| 73 |
|
| 74 |
async def voice_synthesize(self, text: str, language: str = "en-US") -> bytes:
|
| 75 |
-
"""Synthèse vocale avec TxAgent"""
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
try:
|
| 80 |
-
data = {
|
| 81 |
-
"text": text,
|
| 82 |
-
"language": language,
|
| 83 |
-
"return_format": "mp3"
|
| 84 |
-
}
|
| 85 |
-
|
| 86 |
-
async with session.post(url, json=data) as response:
|
| 87 |
-
return await response.read()
|
| 88 |
-
except Exception as e:
|
| 89 |
-
logger.error(f"Error in voice synthesis: {e}")
|
| 90 |
-
if self.config.get_txagent_mode() == "local":
|
| 91 |
-
self.config.mode = "cloud"
|
| 92 |
-
return await self.voice_synthesize(text, language)
|
| 93 |
-
else:
|
| 94 |
-
raise
|
| 95 |
|
| 96 |
async def get_status(self) -> Dict[str, Any]:
|
| 97 |
-
"""Obtient le statut du service TxAgent"""
|
| 98 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
async def get_analysis_results(self, name: Optional[str] = None) -> List[Dict[str, Any]]:
|
| 101 |
-
"""Get patient analysis results from TxAgent service"""
|
| 102 |
try:
|
| 103 |
-
#
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
params['name'] = name
|
| 107 |
|
| 108 |
-
#
|
| 109 |
-
|
| 110 |
-
if params:
|
| 111 |
-
query_string = "&".join([f"{k}={v}" for k, v in params.items()])
|
| 112 |
-
endpoint = f"{endpoint}?{query_string}"
|
| 113 |
|
| 114 |
-
return
|
| 115 |
except Exception as e:
|
| 116 |
-
logger.
|
| 117 |
-
# Return empty results if external API is not available
|
| 118 |
-
# In a real implementation, you would query your local database
|
| 119 |
return []
|
| 120 |
|
| 121 |
async def get_chats(self) -> List[Dict[str, Any]]:
|
| 122 |
-
"""Obtient l'historique des chats"""
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
# Instance globale
|
| 139 |
txagent_service = TxAgentService()
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
from typing import Optional, Dict, Any, List
|
| 3 |
from core.txagent_config import txagent_config
|
| 4 |
+
from db.mongo import db
|
| 5 |
|
| 6 |
logger = logging.getLogger(__name__)
|
| 7 |
|
| 8 |
class TxAgentService:
    """Thin service layer the API routes call into.

    Wraps the integrated (in-process) TxAgent engine. Several methods are
    placeholder/mock implementations until the full agent is wired up.
    """

    def __init__(self):
        # Shared configuration singleton (mode, URLs, GPU availability).
        self.config = txagent_config

    async def chat(self, message: str, history: Optional[list] = None, patient_id: Optional[str] = None) -> Dict[str, Any]:
        """Chat with the integrated TxAgent (placeholder implementation)."""
        return {
            "response": f"TxAgent integrated response: {message}",
            "status": "success"
        }

    async def analyze_patient(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
        """Analyze patient data with the integrated TxAgent (mock result for now)."""
        return {
            "analysis": "Mock patient analysis from integrated TxAgent",
            "status": "success"
        }

    async def voice_transcribe(self, audio_data: bytes) -> Dict[str, Any]:
        """Voice transcription with the integrated TxAgent (mock result for now)."""
        return {
            "transcription": "Mock voice transcription from integrated TxAgent",
            "status": "success"
        }

    async def voice_synthesize(self, text: str, language: str = "en-US") -> bytes:
        """Voice synthesis with the integrated TxAgent (mock audio bytes for now)."""
        return b"Mock audio data from integrated TxAgent"

    async def get_status(self) -> Dict[str, Any]:
        """Report the status of the integrated TxAgent service."""
        return {
            "status": "running",
            "mode": "integrated",
            "version": "2.6.0"
        }

    async def get_analysis_results(self, name: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get patient analysis results from the integrated TxAgent service."""
        try:
            # TxAgent runs in-process, so results can be read straight from the
            # local database once implemented; empty until then.
            logger.info(f"Getting analysis results for name: {name}")
            # TODO: Implement actual analysis results query from local database
            # (this would query the analysis collection in MongoDB).
            return []
        except Exception as e:
            logger.error(f"Error getting analysis results from integrated TxAgent: {e}")
            return []

    async def get_chats(self) -> List[Dict[str, Any]]:
        """Return the 50 most recent chat records from the local database."""
        try:
            recent_cursor = db.chats.find().sort("timestamp", -1).limit(50)
            recent_chats = await recent_cursor.to_list(length=50)

            serialized = []
            for chat in recent_chats:
                serialized.append({
                    "id": str(chat["_id"]),
                    "message": chat.get("message", ""),
                    "response": chat.get("response", ""),
                    "timestamp": chat.get("timestamp"),
                    "user_id": str(chat.get("user_id", "")),
                    "patient_id": str(chat.get("patient_id", "")) if chat.get("patient_id") else None
                })
            return serialized
        except Exception as e:
            logger.error(f"Error getting chats from integrated service: {e}")
            return []
|
| 86 |
|
| 87 |
# Module-level singleton shared by the API routes.
txagent_service = TxAgentService()
|
core/txagent_config.py
CHANGED
|
@@ -6,9 +6,10 @@ logger = logging.getLogger(__name__)
|
|
| 6 |
|
| 7 |
class TxAgentConfig:
|
| 8 |
def __init__(self):
|
|
|
|
| 9 |
self.mode = os.getenv("TXAGENT_MODE", "local") # local, cloud, hybrid
|
| 10 |
self.cloud_url = os.getenv("TXAGENT_CLOUD_URL", "https://rocketfarmstudios-txagent-api.hf.space")
|
| 11 |
-
self.local_enabled = os.getenv("TXAGENT_LOCAL_ENABLED", "
|
| 12 |
self.gpu_available = self._check_gpu_availability()
|
| 13 |
|
| 14 |
def _check_gpu_availability(self) -> bool:
|
|
@@ -21,23 +22,17 @@ class TxAgentConfig:
|
|
| 21 |
|
| 22 |
def get_txagent_mode(self) -> str:
|
| 23 |
"""Détermine le mode optimal pour TxAgent"""
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
elif self.mode == "local" and self.local_enabled and self.gpu_available:
|
| 27 |
-
return "local"
|
| 28 |
-
else:
|
| 29 |
-
return "cloud" # Fallback vers cloud
|
| 30 |
|
| 31 |
def get_txagent_url(self) -> str:
|
| 32 |
"""Retourne l'URL du service TxAgent"""
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
else:
|
| 36 |
-
return self.cloud_url
|
| 37 |
|
| 38 |
def is_local_available(self) -> bool:
|
| 39 |
"""Vérifie si le mode local est disponible"""
|
| 40 |
-
return
|
| 41 |
|
| 42 |
# Instance globale
|
| 43 |
txagent_config = TxAgentConfig()
|
|
|
|
| 6 |
|
| 7 |
class TxAgentConfig:
|
| 8 |
def __init__(self):
    # TxAgent now runs in-process, so "local" is the sensible default mode.
    self.mode = os.getenv("TXAGENT_MODE", "local")  # local, cloud, hybrid
    # Cloud fallback endpoint, overridable via environment.
    self.cloud_url = os.getenv("TXAGENT_CLOUD_URL", "https://rocketfarmstudios-txagent-api.hf.space")
    # Local execution is opt-out: enabled unless TXAGENT_LOCAL_ENABLED is set to something other than "true".
    self.local_enabled = os.getenv("TXAGENT_LOCAL_ENABLED", "true").lower() == "true"
    # Probed once at construction time; see _check_gpu_availability().
    self.gpu_available = self._check_gpu_availability()
|
| 14 |
|
| 15 |
def _check_gpu_availability(self) -> bool:
|
|
|
|
| 22 |
|
| 23 |
def get_txagent_mode(self) -> str:
    """Return the operating mode for TxAgent.

    The engine is now integrated in-process, so the answer is always
    ``"local"`` regardless of configuration.
    """
    return "local"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def get_txagent_url(self) -> str:
    """Return the base URL of the TxAgent service.

    The integrated engine is served from the same process/port as the
    main API, hence the fixed localhost address.
    """
    return "http://localhost:7860"
|
|
|
|
|
|
|
| 32 |
|
| 33 |
def is_local_available(self) -> bool:
    """Whether local mode can be used — always True for the integrated engine."""
    return True
|
| 36 |
|
| 37 |
# Module-level singleton shared across the application.
txagent_config = TxAgentConfig()
|
data/new_tool.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[]
|
db/mongo.py
CHANGED
|
@@ -15,6 +15,12 @@ appointments_collection = db.appointments
|
|
| 15 |
messages_collection = db.messages
|
| 16 |
password_reset_codes_collection = db.password_reset_codes
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# Create indexes for better duplicate detection
|
| 19 |
async def create_indexes():
|
| 20 |
"""Create database indexes for better performance and duplicate detection"""
|
|
@@ -51,6 +57,27 @@ async def create_indexes():
|
|
| 51 |
("source", 1)
|
| 52 |
])
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
print("Database indexes created successfully")
|
| 55 |
|
| 56 |
except Exception as e:
|
|
|
|
| 15 |
messages_collection = db.messages
|
| 16 |
password_reset_codes_collection = db.password_reset_codes
|
| 17 |
|
| 18 |
+
# Collections used by the integrated TxAgent features (analysis results,
# chat history, clinical alerts, and user notifications).
patient_analysis_results_collection = db.patient_analysis_results
chats_collection = db.chats
clinical_alerts_collection = db.clinical_alerts
notifications_collection = db.notifications
|
| 23 |
+
|
| 24 |
# Create indexes for better duplicate detection
|
| 25 |
async def create_indexes():
|
| 26 |
"""Create database indexes for better performance and duplicate detection"""
|
|
|
|
| 57 |
("source", 1)
|
| 58 |
])
|
| 59 |
|
| 60 |
+
# TxAgent indexes
|
| 61 |
+
await patient_analysis_results_collection.create_index([
|
| 62 |
+
("patient_id", 1),
|
| 63 |
+
("timestamp", -1)
|
| 64 |
+
])
|
| 65 |
+
|
| 66 |
+
await chats_collection.create_index([
|
| 67 |
+
("user_id", 1),
|
| 68 |
+
("timestamp", -1)
|
| 69 |
+
])
|
| 70 |
+
|
| 71 |
+
await clinical_alerts_collection.create_index([
|
| 72 |
+
("patient_id", 1),
|
| 73 |
+
("timestamp", -1)
|
| 74 |
+
])
|
| 75 |
+
|
| 76 |
+
await notifications_collection.create_index([
|
| 77 |
+
("user_id", 1),
|
| 78 |
+
("timestamp", -1)
|
| 79 |
+
])
|
| 80 |
+
|
| 81 |
print("Database indexes created successfully")
|
| 82 |
|
| 83 |
except Exception as e:
|
requirements.txt
CHANGED
|
@@ -1,15 +1,33 @@
|
|
| 1 |
-
fastapi
|
| 2 |
-
uvicorn
|
| 3 |
motor
|
| 4 |
python-jose[cryptography]
|
| 5 |
passlib[bcrypt]
|
| 6 |
certifi
|
| 7 |
bcrypt==4.0.1
|
| 8 |
email-validator
|
| 9 |
-
python-multipart
|
| 10 |
requests
|
| 11 |
gradio
|
| 12 |
python-dotenv>=0.21.0
|
| 13 |
-
aiohttp
|
| 14 |
fastapi-mail
|
| 15 |
-
jinja2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.68.0
|
| 2 |
+
uvicorn>=0.15.0
|
| 3 |
motor
|
| 4 |
python-jose[cryptography]
|
| 5 |
passlib[bcrypt]
|
| 6 |
certifi
|
| 7 |
bcrypt==4.0.1
|
| 8 |
email-validator
|
| 9 |
+
python-multipart>=0.0.5
|
| 10 |
requests
|
| 11 |
gradio
|
| 12 |
python-dotenv>=0.21.0
|
|
|
|
| 13 |
fastapi-mail
|
| 14 |
+
jinja2
|
| 15 |
+
pandas>=1.3.0
|
| 16 |
+
pdfplumber>=0.6.0
|
| 17 |
+
fpdf2>=2.5.5
|
| 18 |
+
matplotlib>=3.4.0
|
| 19 |
+
transformers>=4.36.0
|
| 20 |
+
sentence-transformers>=2.2.2
|
| 21 |
+
accelerate>=0.24.1
|
| 22 |
+
tooluniverse
|
| 23 |
+
markdown
|
| 24 |
+
PyPDF2
|
| 25 |
+
pymongo
|
| 26 |
+
SpeechRecognition
|
| 27 |
+
gTTS
|
| 28 |
+
pydub
|
| 29 |
+
PyMuPDF  # provides the "fitz" import; the PyPI package literally named "fitz" is an unrelated project
|
| 30 |
+
python-docx
|
| 31 |
+
pyfcm
|
| 32 |
+
httpx
|
| 33 |
+
PyJWT  # provides the "jwt" import; the PyPI package literally named "jwt" is a different project
|
src/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Integrated TxAgent package: exposes the agent core and its tool retriever."""
from .txagent import TxAgent
from .toolrag import ToolRAGModel
__all__ = [
    "TxAgent",
    "ToolRAGModel",
]
|
src/toolrag.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
from sentence_transformers import SentenceTransformer
|
| 5 |
+
from .utils import get_md5
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ToolRAGModel:
    """Embedding-based retriever matching a free-text query against tool descriptions.

    Tool-description embeddings are cached on disk, keyed by an MD5 of the
    serialized toolbox, so the cache invalidates whenever the tool set changes.
    """

    def __init__(self, rag_model_name):
        self.rag_model_name = rag_model_name
        self.rag_model = None
        self.tool_desc_embedding = None
        self.tool_name = None
        self.tool_embedding_path = None
        self.load_rag_model()

    def load_rag_model(self):
        """Instantiate the sentence-transformer used for similarity search."""
        self.rag_model = SentenceTransformer(self.rag_model_name)
        self.rag_model.max_seq_length = 4096
        self.rag_model.tokenizer.padding_side = "right"

    def load_tool_desc_embedding(self, toolbox):
        """Embed every tool description, reusing the on-disk cache when valid."""
        self.tool_name, _ = toolbox.refresh_tool_name_desc(enable_full_desc=True)
        serialized_tools = [json.dumps(tool) for tool in toolbox.prepare_tool_prompts(toolbox.all_tools)]
        md5_value = get_md5(str(serialized_tools))
        print("Computed MD5 for tool embedding:", md5_value)

        cache_filename = self.rag_model_name.split("/")[-1] + f"_tool_embedding_{md5_value}.pt"
        self.tool_embedding_path = os.path.join(os.path.dirname(__file__), cache_filename)

        if os.path.exists(self.tool_embedding_path):
            try:
                self.tool_desc_embedding = torch.load(self.tool_embedding_path, map_location="cpu")
                assert len(self.tool_desc_embedding) == len(toolbox.all_tools), \
                    "Tool count mismatch with loaded embeddings."
                print("\033[92mLoaded cached tool_desc_embedding.\033[0m")
                return
            except Exception as e:
                # Corrupt or stale cache: fall through and regenerate.
                print(f"⚠️ Failed loading cached embeddings: {e}")
                self.tool_desc_embedding = None

        print("\033[93mGenerating new tool_desc_embedding...\033[0m")
        self.tool_desc_embedding = self.rag_model.encode(
            serialized_tools, prompt="", normalize_embeddings=True
        )

        torch.save(self.tool_desc_embedding, self.tool_embedding_path)
        print(f"\033[92mSaved new tool_desc_embedding to {self.tool_embedding_path}\033[0m")

    def rag_infer(self, query, top_k=5):
        """Return the names of the *top_k* tools most similar to *query*."""
        torch.cuda.empty_cache()
        query_embeddings = self.rag_model.encode(
            [query], prompt="", normalize_embeddings=True
        )
        if self.tool_desc_embedding is None:
            raise RuntimeError("❌ tool_desc_embedding is not initialized. Did you forget to call load_tool_desc_embedding()?")

        scores = self.rag_model.similarity(
            query_embeddings, self.tool_desc_embedding
        )
        effective_k = min(top_k, len(self.tool_name))
        best_indices = torch.topk(scores, effective_k).indices.tolist()[0]
        return [self.tool_name[i] for i in best_indices]
|
src/txagent.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Dict, Optional, List, Union
|
| 5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
|
| 8 |
+
# Configure logging for Hugging Face Spaces
|
| 9 |
+
logging.basicConfig(
|
| 10 |
+
level=logging.INFO,
|
| 11 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 12 |
+
)
|
| 13 |
+
logger = logging.getLogger("TxAgent")
|
| 14 |
+
|
| 15 |
+
class TxAgent:
    """Integrated TxAgent: a causal-LM chat model plus an optional
    sentence-transformer retrieval (RAG) model.

    Heavy models are loaded lazily via :meth:`init_model`, so constructing
    the object is cheap.
    """

    def __init__(self,
                 model_name: str,
                 rag_model_name: str,
                 tool_files_dict: Optional[Dict] = None,
                 enable_finish: bool = True,
                 enable_rag: bool = False,
                 force_finish: bool = True,
                 enable_checker: bool = True,
                 step_rag_num: int = 4,
                 seed: Optional[int] = None):

        # Initialization parameters
        self.model_name = model_name
        self.rag_model_name = rag_model_name
        self.tool_files_dict = tool_files_dict or {}
        self.enable_finish = enable_finish
        self.enable_rag = enable_rag
        self.force_finish = force_finish
        self.enable_checker = enable_checker
        self.step_rag_num = step_rag_num
        self.seed = seed

        # Prefer GPU when available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Models; populated by init_model().
        self.model = None
        self.tokenizer = None
        self.rag_model = None

        # System prompt used for plain chat.
        self.chat_prompt = "You are a helpful assistant for user chat."

        logger.info(f"Initialized TxAgent with model: {model_name}")

    def init_model(self):
        """Load the LLM (and the RAG model when enabled). Raises on failure."""
        try:
            self.load_llm_model()
            if self.enable_rag:
                self.load_rag_model()
            logger.info("Models initialized successfully")
        except Exception as e:
            logger.error(f"Model initialization failed: {str(e)}")
            raise

    def load_llm_model(self):
        """Load tokenizer + causal LM (fp16 on CUDA, fp32 on CPU)."""
        try:
            logger.info(f"Loading LLM model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto",
                trust_remote_code=True
            )
            logger.info(f"LLM model loaded on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load LLM model: {str(e)}")
            raise

    def load_rag_model(self):
        """Load the sentence-transformer used for retrieval."""
        try:
            logger.info(f"Loading RAG model: {self.rag_model_name}")
            self.rag_model = SentenceTransformer(
                self.rag_model_name,
                device=str(self.device)
            )
            logger.info("RAG model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load RAG model: {str(e)}")
            raise

    def chat(self, message: str, history: Optional[List[Dict]] = None,
             temperature: float = 0.7, max_new_tokens: int = 512) -> str:
        """Generate a reply to *message*, given optional prior turns.

        *history* items are ``{"role": ..., "content": ...}`` dicts.
        Raises RuntimeError wrapping any underlying failure.
        """
        try:
            # Build the conversation: system prompt, prior turns, new message.
            conversation = [{"role": "system", "content": self.chat_prompt}]
            if history:
                for msg in history:
                    conversation.append({"role": msg["role"], "content": msg["content"]})
            conversation.append({"role": "user", "content": message})

            input_ids = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.device)

            generation_config = GenerationConfig(
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

            outputs = self.model.generate(
                input_ids,
                generation_config=generation_config
            )

            # Decode only the newly generated tail (strip the prompt tokens).
            response = self.tokenizer.decode(
                outputs[0][input_ids.shape[1]:], skip_special_tokens=True
            )
            return response.strip()

        except Exception as e:
            logger.error(f"Chat failed: {str(e)}")
            raise RuntimeError(f"Chat failed: {str(e)}")

    def cleanup(self):
        """Release model references and free cached GPU memory. Raises on failure."""
        try:
            if hasattr(self, 'model'):
                del self.model
            if hasattr(self, 'rag_model'):
                del self.rag_model
            torch.cuda.empty_cache()
            logger.info("Resources cleaned up")
        except Exception as e:
            logger.error(f"Cleanup failed: {str(e)}")
            raise

    def __del__(self):
        """Best-effort cleanup; never propagate from a finalizer."""
        # Fix: the previous version re-raised cleanup failures. Exceptions
        # escaping __del__ are unraisable and can fire during interpreter
        # shutdown (when logging/torch may already be torn down), so swallow.
        try:
            self.cleanup()
        except Exception:
            pass
|
src/utils.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import json
|
| 3 |
+
import hashlib
|
| 4 |
+
import torch
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_md5(input_str):
    """Return the hexadecimal MD5 digest of *input_str* (UTF-8 encoded)."""
    return hashlib.md5(input_str.encode('utf-8')).hexdigest()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def tool_result_format(function_call_messages):
    """Render tool-role messages as a collapsible markdown ``<details>`` section.

    Tool messages whose content is JSON are shown with their tool name;
    anything unparsable is included verbatim.
    """
    rendered = "\n\n<details>\n<summary> <strong>Verified Feedback from Tools</strong>, click to see details:</summary>\n\n"
    for message in function_call_messages:
        if message['role'] != 'tool':
            continue
        try:
            parsed = json.loads(message['content'])
            tool_name = parsed.get("tool_name", "Unknown Tool")
            tool_output = parsed.get("content", message['content'])
            rendered += f"**🔧 Tool: {tool_name}**\n\n{tool_output}\n\n"
        except Exception:
            # Not JSON — include the raw content as-is.
            rendered += f"{message['content']}\n\n"
    return rendered + "</details>\n\n\n"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class NoRepeatSentenceProcessor:
    """Logits processor that prevents regenerating forbidden token sequences.

    For each forbidden sequence longer than ``allowed_prefix_length``, the
    token immediately following that prefix is masked to ``-inf`` whenever
    the current generation starts with exactly that prefix.
    """

    def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
        self.allowed_prefix_length = allowed_prefix_length
        # Maps a prefix tuple -> set of banned next-token ids.
        self.forbidden_prefix_dict = {}
        for seq in forbidden_sequences:
            if len(seq) <= allowed_prefix_length:
                continue
            prefix_key = tuple(seq[:allowed_prefix_length])
            banned_token = seq[allowed_prefix_length]
            self.forbidden_prefix_dict.setdefault(prefix_key, set()).add(banned_token)

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        if len(token_ids) >= self.allowed_prefix_length:
            prefix_key = tuple(token_ids[:self.allowed_prefix_length])
            banned = self.forbidden_prefix_dict.get(prefix_key)
            if banned is not None:
                for token_id in banned:
                    logits[token_id] = -float("inf")
        return logits
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ReasoningTraceChecker:
    """Detects repeated thoughts/actions in an agent conversation trace.

    Scans assistant turns starting from ``self.index`` and fails fast as
    soon as a previously seen thought or tool action recurs.
    """

    def __init__(self, question, conversation, init_index=None):
        self.question = question.lower()
        self.conversation = conversation
        self.existing_thoughts = []
        self.existing_actions = []
        self.new_thoughts = []
        self.new_actions = []
        self.index = 1 if init_index is None else init_index

    def check_conversation(self):
        """Return (ok, info): ok is False on the first detected repeat."""
        info = ''
        for i in range(self.index, len(self.conversation)):
            turn = self.conversation[i]
            self.index = i
            if turn['role'] != 'assistant':
                continue
            ok, detail = self.check_repeat_thought(turn['content'])
            info += detail
            if not ok:
                return False, info
            ok, detail = self.check_repeat_action(turn['tool_calls'])
            info += detail
            if not ok:
                return False, info
        return True, info

    def check_repeat_thought(self, thought):
        """Record *thought*; fail if it was already seen."""
        if thought in self.existing_thoughts:
            return False, "repeat_thought"
        self.existing_thoughts.append(thought)
        return True, ''

    def check_repeat_action(self, actions):
        """Record each action (call_id stripped); fail on a duplicate."""
        if type(actions) != list:
            actions = json.loads(actions)
        for action in actions:
            action.pop('call_id', None)
            serialized = json.dumps(action)
            if serialized in self.existing_actions:
                return False, "repeat_action"
            self.existing_actions.append(serialized)
        return True, ''
|
utils.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import hashlib
|
| 3 |
+
import io
|
| 4 |
+
import json
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Dict, List, Tuple
|
| 7 |
+
from bson import ObjectId
|
| 8 |
+
import logging
|
| 9 |
+
from config import logger
|
| 10 |
+
# Add to your utils.py
|
| 11 |
+
from fastapi import WebSocket
|
| 12 |
+
import asyncio
|
| 13 |
+
|
| 14 |
+
class NotificationManager:
    """Tracks live WebSocket connections and queues outbound notifications."""

    def __init__(self):
        # user_id -> WebSocket
        self.active_connections = {}
        self.notification_queue = asyncio.Queue()

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept the socket and register it under *user_id*."""
        await websocket.accept()
        self.active_connections[user_id] = websocket

    def disconnect(self, user_id: str):
        """Forget the connection registered for *user_id*, if any."""
        self.active_connections.pop(user_id, None)

    async def broadcast_notification(self, notification: dict):
        """Send *notification* to every connected client, logging per-socket failures."""
        for connection in list(self.active_connections.values()):
            try:
                await connection.send_json({
                    "type": "notification",
                    "data": notification
                })
            except Exception as e:
                logger.error(f"Error sending notification: {e}")

notification_manager = NotificationManager()
|
| 39 |
+
|
| 40 |
+
async def broadcast_notification(notification: dict):
    """Queue *notification* for the roles that should receive it.

    High-priority alerts fan out to the clinical escalation roles;
    everything else goes to routine-care staff.
    """
    if notification["priority"] == "high":
        recipients = ["psychiatrist", "emergency_team", "primary_care"]
    else:
        recipients = ["primary_care", "case_manager"]

    await notification_manager.notification_queue.put({
        "recipients": recipients,
        "notification": notification
    })
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def clean_text_response(text: str) -> str:
    """Normalize LLM output: collapse blank-line runs and repeated spaces,
    strip markdown bold/underline markers, and trim surrounding whitespace."""
    normalized = re.sub(r'\n\s*\n', '\n\n', text)
    normalized = re.sub(r'[ ]+', ' ', normalized)
    for marker in ("**", "__"):
        normalized = normalized.replace(marker, "")
    return normalized.strip()
|
| 61 |
+
|
| 62 |
+
def extract_section(text: str, heading: str) -> str:
    """Return the body under ``heading:`` up to the next ``Something:`` line.

    Returns "" when the heading is absent or extraction fails.
    """
    try:
        pattern = rf"{re.escape(heading)}:\s*\n(.*?)(?=\n[A-Z][^\n]*:|\Z)"
        found = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if found is None:
            return ""
        return found.group(1).strip()
    except Exception as e:
        logger.error(f"Section extraction failed for heading '{heading}': {e}")
        return ""
|
| 70 |
+
|
| 71 |
+
def structure_medical_response(text: str) -> Dict:
    """Split a free-text medical LLM response into labeled sections.

    Each section is located by trying several heading formats (plain, bold,
    dash/space separated); alternate heading phrasings are also attempted.
    Missing sections come back as "".
    """

    def grab(source: str, heading: str) -> str:
        escaped = re.escape(heading)
        candidates = [
            rf"{escaped}:\s*\n(.*?)(?=\n\s*\n|\Z)",
            rf"\*\*{escaped}\*\*:\s*\n(.*?)(?=\n\s*\n|\Z)",
            rf"{escaped}[\s\-]+(.*?)(?=\n\s*\n|\Z)",
            rf"\n{escaped}\s*\n(.*?)(?=\n\s*\n|\Z)",
        ]
        for candidate in candidates:
            found = re.search(candidate, source, re.DOTALL | re.IGNORECASE)
            if found:
                section = found.group(1).strip()
                # Drop leading bullet markers from each line.
                return re.sub(r'^\s*[\-\*]\s*', '', section, flags=re.MULTILINE)
        return ""

    # Remove markdown emphasis so plain-heading patterns also match.
    text = text.replace('**', '').replace('__', '')
    return {
        "summary": grab(text, "Summary of Patient's Medical History") or
                   grab(text, "Summarize the patient's medical history"),
        "risks": grab(text, "Identify Risks or Red Flags") or
                 grab(text, "Risks or Red Flags"),
        "missed_issues": grab(text, "Missed Diagnoses or Treatments") or
                         grab(text, "What the doctor might have missed"),
        "recommendations": grab(text, "Suggest Next Clinical Steps") or
                           grab(text, "Suggested Clinical Actions"),
    }
|
| 98 |
+
|
| 99 |
+
def serialize_patient(patient: dict) -> dict:
    """Return a shallow copy of *patient* with its ``_id`` stringified for JSON.

    The input dict is never mutated.
    """
    serializable = dict(patient)
    if "_id" in serializable:
        serializable["_id"] = str(serializable["_id"])
    return serializable
|
| 104 |
+
|
| 105 |
+
def compute_patient_data_hash(data: dict) -> str:
    """Stable SHA-256 over the canonical (key-sorted) JSON form of *data*,
    so logically equal dicts hash identically regardless of key order."""
    canonical = json.dumps(data, sort_keys=True)
    return hashlib.sha256(canonical.encode()).hexdigest()
|
| 108 |
+
|
| 109 |
+
def compute_file_content_hash(file_content: bytes) -> str:
    """SHA-256 hex digest of raw file bytes (used for duplicate detection)."""
    digest = hashlib.sha256()
    digest.update(file_content)
    return digest.hexdigest()
|
voice.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from fastapi import HTTPException
|
| 3 |
+
from config import logger
|
| 4 |
+
import io
|
| 5 |
+
import speech_recognition as sr
|
| 6 |
+
from gtts import gTTS
|
| 7 |
+
from pydub import AudioSegment
|
| 8 |
+
import base64
|
| 9 |
+
from utils import clean_text_response # Added this import
|
| 10 |
+
|
| 11 |
+
def recognize_speech(audio_data: bytes, language: str = "en-US") -> str:
    """Transcribe audio bytes via the Google Speech Recognition web API.

    Args:
        audio_data: Raw audio file bytes in a format ``sr.AudioFile`` accepts
            (WAV/AIFF/FLAC — assumed; confirm against callers).
        language: Language tag forwarded to the recognizer.

    Returns:
        The recognized transcript.

    Raises:
        HTTPException: 400 if the audio is unintelligible, 503 if the
            recognition service is unreachable, 500 for any other failure.
    """
    recognizer = sr.Recognizer()
    try:
        buffer = io.BytesIO(audio_data)
        with buffer, sr.AudioFile(buffer) as source:
            captured = recognizer.record(source)
        # Audio is fully captured above; the network call happens here.
        return recognizer.recognize_google(captured, language=language)
    except sr.UnknownValueError:
        logger.error("Google Speech Recognition could not understand audio")
        raise HTTPException(status_code=400, detail="Could not understand audio")
    except sr.RequestError as e:
        logger.error(f"Could not request results from Google Speech Recognition service; {e}")
        raise HTTPException(status_code=503, detail="Speech recognition service unavailable")
    except Exception as e:
        logger.error(f"Error in speech recognition: {e}")
        raise HTTPException(status_code=500, detail="Error processing speech")
|
| 28 |
+
|
| 29 |
+
def text_to_speech(text: str, language: str = "en", slow: bool = False) -> bytes:
    """Render *text* to MP3 bytes with Google Text-to-Speech.

    Args:
        text: The text to vocalize.
        language: gTTS language code.
        slow: Whether to use the slower speaking rate.

    Returns:
        The MP3 audio as raw bytes.

    Raises:
        HTTPException: 500 on any synthesis failure.
    """
    try:
        buffer = io.BytesIO()
        gTTS(text=text, lang=language, slow=slow).write_to_fp(buffer)
        return buffer.getvalue()
    except Exception as e:
        logger.error(f"Error in text-to-speech conversion: {e}")
        raise HTTPException(status_code=500, detail="Error generating speech")
|
| 39 |
+
|
| 40 |
+
def extract_text_from_pdf(pdf_data: bytes) -> str:
    """Extract and clean the concatenated text of every page of a PDF.

    Args:
        pdf_data: Raw PDF file bytes.

    Returns:
        The page texts joined in order, passed through ``clean_text_response``.

    Raises:
        HTTPException: 400 when the PDF cannot be parsed.
    """
    try:
        # Imported lazily so the module loads even without PyPDF2 installed.
        from PyPDF2 import PdfReader
        reader = PdfReader(io.BytesIO(pdf_data))
        raw = "".join(page.extract_text() or "" for page in reader.pages)
        return clean_text_response(raw)
    except Exception as e:
        logger.error(f"Error extracting text from PDF: {e}")
        raise HTTPException(status_code=400, detail="Failed to extract text from PDF")
|