# Snapshot of a Hugging Face Space source file (commit 79bcb1b, 23,227 bytes).
import torch
import numpy as np
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
AutoModelForSequenceClassification,
AutoModelForSeq2SeqLM,
pipeline
)
import re
import os
import json
from typing import Dict, List, Tuple, Any
class SymptomExtractor:
    """Extract symptom spans and duration phrases from a patient description.

    Symptom spans come from a Hugging Face ``"ner"`` pipeline built on
    ``model_name`` (BioBERT by default) and are assembled from ``B-``/``I-``
    tagged tokens. Duration phrases come from simple regular expressions.

    NOTE(review): ``dmis-lab/biobert-v1.1`` looks like a base checkpoint; when
    loaded with ``AutoModelForTokenClassification`` its classification head is
    presumably untrained, so its labels may not follow the ``B-``/``I-`` scheme
    this class expects unless a fine-tuned NER checkpoint is supplied — confirm.
    """

    # Duration patterns, applied in this order. Plural suffixes are optional
    # (``days?``) so the whole word is captured ("3 days", not "3 day"), and
    # word boundaries keep "for"/"since" from matching inside other words.
    _DURATION_PATTERNS = (
        re.compile(r"(\d+)\s*(days?|weeks?|months?|years?)\b", re.IGNORECASE),
        re.compile(r"\bsince\s+(\w+)", re.IGNORECASE),
        re.compile(r"\bfor\s+(\w+)", re.IGNORECASE),
    )

    def __init__(self, model_name="dmis-lab/biobert-v1.1", device=None):
        """Load the tokenizer/model and wrap them in an NER pipeline.

        Args:
            model_name: Hugging Face model id or local path.
            device: Torch device string; defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Symptom Extractor model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name).to(self.device)
        # The pipeline API takes a GPU index (0) or -1 for CPU.
        self.nlp = pipeline("ner", model=self.model, tokenizer=self.tokenizer, device=0 if self.device == "cuda" else -1)
        print("Symptom Extractor model loaded successfully.")

    def extract_symptoms(self, text: str) -> Dict[str, Any]:
        """Extract symptom spans and duration mentions from ``text``.

        Args:
            text: Free-text patient description.

        Returns:
            ``{"symptoms": [...], "duration": [...]}``. Each symptom is a dict
            with ``text``, ``start``, ``end`` and ``score`` (the mean score of
            the tokens forming the span). Each duration is a dict with
            ``text``, ``start`` and ``end`` of the regex match.
        """
        symptoms = []
        current = None  # span being assembled from B-/I- tokens
        n_tokens = 0    # number of tokens averaged into current["score"]

        for entity in self.nlp(text):
            label = entity["entity"]
            word = entity["word"]
            if label.startswith("B-"):  # beginning of a new symptom span
                if current is not None:
                    symptoms.append(current)
                current = {
                    "text": word[2:] if word.startswith("##") else word,
                    "start": entity["start"],
                    "end": entity["end"],
                    "score": float(entity["score"]),
                }
                n_tokens = 1
            elif label.startswith("I-") and current is not None:  # continuation
                # "##" marks a word-piece continuation: glue it on without a
                # space so "head" + "##ache" becomes "headache".
                if word.startswith("##"):
                    current["text"] += word[2:]
                else:
                    current["text"] += " " + word
                current["end"] = entity["end"]
                n_tokens += 1
                # Incremental mean so every token weighs equally.
                current["score"] += (float(entity["score"]) - current["score"]) / n_tokens
            # I- tokens with no open span (and other labels) are ignored.

        if current is not None:
            symptoms.append(current)

        duration_info = [
            {"text": match.group(0), "start": match.start(), "end": match.end()}
            for pattern in self._DURATION_PATTERNS
            for match in pattern.finditer(text)
        ]

        return {
            "symptoms": symptoms,
            "duration": duration_info,
        }
class RiskClassifier:
    """Classify a patient description as Low / Medium / High risk.

    A PubMedBERT sequence classifier with a 3-way head produces the initial
    prediction. Because that head is not fine-tuned, a keyword rule layer
    then adjusts the prediction, and the reported probabilities are reshaped
    so the final label carries at least ``_MIN_CONFIDENCE`` of the mass.
    """

    # Keyword tables for the rule-based post-processing. Matching is plain
    # case-insensitive substring search (e.g. "pain" also matches "painful").
    HIGH_RISK_KEYWORDS = (
        "severe", "extreme", "intense", "unbearable", "emergency",
        "chest pain", "difficulty breathing", "can't breathe",
        "losing consciousness", "fainted", "seizure", "stroke", "heart attack",
        "allergic reaction", "bleeding heavily", "blood", "poisoning",
        "overdose", "suicide", "self-harm", "hallucinations",
    )
    MEDIUM_RISK_KEYWORDS = (
        "worsening", "spreading", "persistent", "chronic", "recurring",
        "infection", "fever", "swelling", "rash", "pain", "vomiting",
        "diarrhea", "dizzy", "headache", "concerning", "worried",
        "weeks", "days", "increasing", "progressing",
    )
    LOW_RISK_KEYWORDS = (
        "mild", "slight", "minor", "occasional", "intermittent",
        "improving", "better", "sometimes", "rarely", "manageable",
    )
    # Minimum probability reported for the label chosen after post-processing.
    _MIN_CONFIDENCE = 0.6

    def __init__(self, model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", device=None):
        """Load the tokenizer and a 3-label sequence-classification model.

        Args:
            model_name: Hugging Face model id or local path.
            device: Torch device string; defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Risk Classifier model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=3  # Low, Medium, High
        ).to(self.device)
        self.id2label = {0: "Low", 1: "Medium", 2: "High"}
        print("Risk Classifier model loaded successfully.")

    def _apply_keyword_rules(self, text: str, model_prediction: int) -> int:
        """Return the label index after keyword-based adjustment.

        The model is not fine-tuned, so its prediction is overridden whenever
        the keyword counts are decisive. Rules are checked in order; the first
        match wins, otherwise ``model_prediction`` is kept.

        NOTE(review): text with exactly one high-risk keyword and fewer than
        one medium-risk keyword (e.g. only "heart attack" or "suicide") keeps
        the untrained model's prediction. Consider escalating such cases.
        """
        text_lower = text.lower()
        high = sum(keyword in text_lower for keyword in self.HIGH_RISK_KEYWORDS)
        medium = sum(keyword in text_lower for keyword in self.MEDIUM_RISK_KEYWORDS)
        low = sum(keyword in text_lower for keyword in self.LOW_RISK_KEYWORDS)

        prediction = model_prediction
        if high >= 2:
            prediction = 2  # High risk
        elif high == 1 and medium >= 2:
            prediction = 2  # High risk
        elif medium >= 3:
            prediction = 1  # Medium risk
        elif medium >= 1 and low <= 1:
            prediction = 1  # Medium risk
        elif low >= 2 and high == 0:
            prediction = 0  # Low risk

        # A long, detailed description (> 40 words) is treated as a sign of a
        # more complex case, so Low is raised to Medium.
        if len(text.split()) > 40 and prediction == 0:
            prediction = 1
        return prediction

    @staticmethod
    def _boost_probability(probabilities, index: int, floor: float = 0.6) -> np.ndarray:
        """Give class ``index`` at least ``floor`` probability, keeping a distribution.

        The other classes are rescaled proportionally into the remaining mass
        ``1 - target``. This keeps ``index`` the most probable class whenever
        ``floor > 0.5``. Renormalising the whole vector after raising one
        entry would not guarantee that.

        Args:
            probabilities: 1-D sequence of class probabilities (summing to 1).
            index: Class to boost.
            floor: Minimum probability assigned to ``index``.

        Returns:
            A new float64 array summing to 1 (up to rounding).
        """
        probs = np.asarray(probabilities, dtype=np.float64)
        target = max(floor, float(probs[index]))
        adjusted = probs.copy()
        adjusted[index] = 0.0
        rest = adjusted.sum()
        remaining = 1.0 - target
        if rest > 0:
            adjusted *= remaining / rest
        elif len(adjusted) > 1:
            # Degenerate input with no mass on the other classes: spread evenly.
            adjusted[:] = remaining / (len(adjusted) - 1)
        adjusted[index] = target
        return adjusted

    def classify_risk(self, text: str) -> Dict[str, Any]:
        """Classify the risk level of ``text``.

        Args:
            text: Free-text patient description (truncated to 512 tokens).

        Returns:
            Dict with ``risk_level`` (final label), ``confidence`` (its
            probability), ``all_probabilities`` (label -> probability, summing
            to 1) and ``original_prediction`` (the model's raw argmax label).
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        probabilities = torch.softmax(logits, dim=1)[0].cpu().numpy()
        model_prediction = int(torch.argmax(logits, dim=1).item())

        adjusted_prediction = self._apply_keyword_rules(text, model_prediction)
        adjusted_probabilities = self._boost_probability(
            probabilities, adjusted_prediction, self._MIN_CONFIDENCE
        )

        return {
            "risk_level": self.id2label[adjusted_prediction],
            "confidence": float(adjusted_probabilities[adjusted_prediction]),
            "all_probabilities": {
                self.id2label[i]: float(prob)
                for i, prob in enumerate(adjusted_probabilities)
            },
            "original_prediction": self.id2label[model_prediction]
        }
class RecommendationGenerator:
    """Generate medical recommendations with a (possibly fine-tuned) T5 model.

    The model's free-text advice is combined with rule-based department
    suggestions (keyword lookup on the symptom text) and risk-dependent
    self-care tips, both as a structured dict and as formatted text.
    """

    # Local directories checked, in order, for a fine-tuned checkpoint when
    # the caller asks for one of DEFAULT_MODEL_NAMES.
    LOCAL_MODEL_CANDIDATES = (
        "./finetuned_t5-small",  # user-specified fine-tuned model path
        "./t5-small-medical-recommendation",
        "./models/t5-small-medical-recommendation",
        "./fine_tuned_models/t5-small",
        "./output",
        "./fine_tuning_output",
    )
    # Requested model names for which a local fine-tuned checkpoint may be
    # substituted. Any other name/path is used exactly as given.
    DEFAULT_MODEL_NAMES = ("t5-small", "t5-small-medical-recommendation")

    # Symptom keyword -> department. Keywords are matched at the start of a
    # word (see _extract_departments_from_symptoms), so "nose" does not fire on
    # "diagnosed" and "ear" does not fire on "heart"/"years".
    # NOTE(review): prefix matches remain, e.g. "ear" still matches "early".
    SYMPTOM_TO_DEPARTMENT = {
        "headache": "Neurology",
        "dizziness": "Neurology",
        "confusion": "Neurology",
        "memory": "Neurology",
        "numbness": "Neurology",
        "tingling": "Neurology",
        "seizure": "Neurology",
        "nerve": "Neurology",
        "chest pain": "Cardiology",
        "heart": "Cardiology",
        "palpitation": "Cardiology",
        "arrhythmia": "Cardiology",
        "high blood pressure": "Cardiology",
        "hypertension": "Cardiology",
        "heart attack": "Cardiology",
        "cardiovascular": "Cardiology",
        "cough": "Pulmonology",
        "breathing": "Pulmonology",
        "shortness": "Pulmonology",
        "lung": "Pulmonology",
        "respiratory": "Pulmonology",
        "asthma": "Pulmonology",
        "pneumonia": "Pulmonology",
        "copd": "Pulmonology",
        "stomach": "Gastroenterology",
        "abdomen": "Gastroenterology",
        "nausea": "Gastroenterology",
        "vomit": "Gastroenterology",
        "diarrhea": "Gastroenterology",
        "constipation": "Gastroenterology",
        "heartburn": "Gastroenterology",
        "liver": "Gastroenterology",
        "digestive": "Gastroenterology",
        "joint": "Orthopedics",
        "bone": "Orthopedics",
        "muscle": "Orthopedics",
        "pain": "Orthopedics",
        "back": "Orthopedics",
        "arthritis": "Orthopedics",
        "fracture": "Orthopedics",
        "sprain": "Orthopedics",
        "rash": "Dermatology",
        "skin": "Dermatology",
        "itching": "Dermatology",
        "itch": "Dermatology",
        "acne": "Dermatology",
        "eczema": "Dermatology",
        "psoriasis": "Dermatology",
        "fever": "General Medicine / Primary Care",
        "infection": "General Medicine / Primary Care",
        "sore throat": "General Medicine / Primary Care",
        "flu": "General Medicine / Primary Care",
        "cold": "General Medicine / Primary Care",
        "fatigue": "General Medicine / Primary Care",
        "pregnancy": "Obstetrics / Gynecology",
        "menstruation": "Obstetrics / Gynecology",
        "period": "Obstetrics / Gynecology",
        "vaginal": "Obstetrics / Gynecology",
        "menopause": "Obstetrics / Gynecology",
        "depression": "Psychiatry",
        "anxiety": "Psychiatry",
        "mood": "Psychiatry",
        "stress": "Psychiatry",
        "sleep": "Psychiatry",
        # Word-start matching no longer finds "sleep" inside "asleep".
        "asleep": "Psychiatry",
        "insomnia": "Psychiatry",
        "mental": "Psychiatry",
        "ear": "Otolaryngology (ENT)",
        "nose": "Otolaryngology (ENT)",
        "throat": "Otolaryngology (ENT)",
        "hearing": "Otolaryngology (ENT)",
        "sinus": "Otolaryngology (ENT)",
        "eye": "Ophthalmology",
        "vision": "Ophthalmology",
        "blindness": "Ophthalmology",
        "blurry": "Ophthalmology",
        "urination": "Urology",
        "kidney": "Urology",
        "bladder": "Urology",
        "urine": "Urology",
        "prostate": "Urology",
    }

    # Self-care suggestions per risk level.
    SELF_CARE_BY_RISK = {
        "Low": [
            "Ensure you're getting adequate rest",
            "Stay hydrated by drinking plenty of water",
            "Monitor your symptoms and note any changes",
            "Consider over-the-counter medications appropriate for your symptoms",
            "Maintain a balanced diet to support your immune system",
            "Try gentle exercises if appropriate for your condition",
            "Avoid activities that worsen your symptoms",
            "Keep track of any patterns in your symptoms"
        ],
        "Medium": [
            "Rest and avoid strenuous activities",
            "Stay hydrated and maintain proper nutrition",
            "Take your temperature and other vital signs if possible",
            "Write down any changes in symptoms and when they occur",
            "Have someone stay with you if your symptoms are concerning",
            "Prepare a list of your symptoms and medications for your doctor",
            "Avoid self-medicating beyond basic over-the-counter remedies",
            "Consider arranging transportation to your medical appointment"
        ],
        "High": [
            "Don't wait - seek medical attention immediately",
            "Have someone drive you to the emergency room if safe to do so",
            "Call emergency services if symptoms are severe",
            "Bring a list of your current medications if possible",
            "Follow any first aid protocols appropriate for your symptoms",
            "Don't eat or drink anything if you might need surgery",
            "Take prescribed emergency medications if applicable (like an inhaler for asthma)",
            "Try to remain calm and focused on getting help"
        ]
    }

    def __init__(self, model_path="t5-small", device=None):
        """Load the seq2seq model, preferring a local fine-tuned checkpoint.

        Args:
            model_path: Hugging Face model id or local path. For the default
                names (see DEFAULT_MODEL_NAMES) a local fine-tuned checkpoint
                from LOCAL_MODEL_CANDIDATES is used if one exists.
            device: Torch device string; defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Recommendation Generator model on {self.device}...")

        model_path = self._resolve_model_path(model_path)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(self.device)
            print(f"Recommendation Generator model '{model_path}' loaded successfully.")
        except Exception as e:
            # Deliberate best-effort fallback (e.g. a candidate directory that
            # exists but holds no model): report and load the base model.
            print(f"Error loading model from {model_path}: {str(e)}")
            print("Falling back to base t5-small model...")
            self.tokenizer = AutoTokenizer.from_pretrained("t5-small")
            self.model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(self.device)
            print("Base t5-small model loaded successfully as fallback.")

        # Per-instance copies so customising one instance never mutates the
        # class-level tables.
        self.symptom_to_department = dict(self.SYMPTOM_TO_DEPARTMENT)
        self.self_care_by_risk = {
            level: list(tips) for level, tips in self.SELF_CARE_BY_RISK.items()
        }

    @classmethod
    def _resolve_model_path(cls, model_path: str) -> str:
        """Pick the model id/path to load.

        An explicitly requested non-default model is returned unchanged, so a
        stray local directory can never override it. For default names, the
        first existing directory in LOCAL_MODEL_CANDIDATES wins; a missing
        "t5-small-medical-recommendation" falls back to "t5-small".
        """
        if model_path not in cls.DEFAULT_MODEL_NAMES:
            return model_path
        for candidate in cls.LOCAL_MODEL_CANDIDATES:
            # isdir, not exists: a plain file such as a log named "output"
            # is not a checkpoint.
            if os.path.isdir(candidate):
                print(f"Found fine-tuned model at: {candidate}")
                return candidate
        if model_path == "t5-small-medical-recommendation":
            print("Fine-tuned model not found locally. Falling back to base t5-small...")
            return "t5-small"
        return model_path

    def _extract_departments_from_symptoms(self, symptoms_text: str) -> List[str]:
        """Return departments whose keywords appear in ``symptoms_text``.

        Keywords must start at a word boundary (case-insensitive). Departments
        are returned in first-match order, following keyword order in
        ``self.symptom_to_department``, so output is deterministic.

        Args:
            symptoms_text: Symptom description text.

        Returns:
            Department names, or ``["General Medicine / Primary Care"]`` when
            nothing matches.
        """
        text_lower = symptoms_text.lower()
        departments = list(dict.fromkeys(
            department
            for keyword, department in self.symptom_to_department.items()
            if re.search(r"\b" + re.escape(keyword), text_lower)
        ))
        return departments or ["General Medicine / Primary Care"]

    def _get_self_care_suggestions(self, risk_level: str) -> List[str]:
        """Return up to five self-care tips for ``risk_level``.

        Unknown risk levels fall back to the "Medium" tips. When more than five
        tips exist, five are sampled at random so repeated calls vary.

        Args:
            risk_level: "Low", "Medium" or "High".

        Returns:
            A new list of suggestion strings. The internal table is never
            returned by reference.
        """
        import random

        if risk_level not in self.self_care_by_risk:
            risk_level = "Medium"
        suggestions = self.self_care_by_risk[risk_level]
        if len(suggestions) > 5:
            return random.sample(suggestions, 5)
        return list(suggestions)

    def _format_structured_recommendation(self, medical_advice: str, departments: List[str], self_care: List[str], risk_level: str) -> str:
        """Render the structured recommendation as plain text.

        Args:
            medical_advice: Main advice generated by the model.
            departments: Suggested departments.
            self_care: Self-care suggestions.
            risk_level: "Low", "Medium" or "High".

        Returns:
            The advice, a "RECOMMENDED DEPARTMENTS" paragraph and a
            "SELF-CARE SUGGESTIONS" bullet list.
        """
        if risk_level == 'High':
            urgency = 'immediate attention'
        elif risk_level == 'Medium':
            urgency = 'medical care soon'
        else:
            urgency = 'monitoring'

        recommendation = medical_advice.strip() + "\n\n"
        recommendation += f"RECOMMENDED DEPARTMENTS: Based on your symptoms, consider consulting the following departments: {', '.join(departments)}.\n\n"
        recommendation += f"SELF-CARE SUGGESTIONS: While {risk_level.lower()} risk level requires {urgency}, you can also:\n"
        recommendation += "".join(f"- {suggestion}\n" for suggestion in self_care)
        return recommendation

    def generate_recommendation(self,
                                symptoms: str,
                                risk_level: str,
                                max_length: int = 150) -> Dict[str, Any]:
        """
        Generate a comprehensive medical recommendation based on symptoms and risk level.

        Args:
            symptoms: Symptom description text
            risk_level: Risk level (Low, Medium, High)
            max_length: Maximum length for generated text

        Returns:
            Dictionary with "text" (formatted recommendation) and "structured"
            (medical advice, department suggestions and self-care tips)
        """
        # Prompt format expected by the (fine-tuned) model.
        input_text = f"Symptoms: {symptoms} Risk: {risk_level}"

        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=4,
                early_stopping=True
            )

        medical_advice = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        departments = self._extract_departments_from_symptoms(symptoms)
        # High risk: put Emergency Medicine first.
        if risk_level == "High" and "Emergency Medicine" not in departments:
            departments.insert(0, "Emergency Medicine")

        self_care_suggestions = self._get_self_care_suggestions(risk_level)

        structured_recommendation = {
            "medical_advice": medical_advice,
            "departments": departments,
            "self_care": self_care_suggestions
        }

        formatted_text = self._format_structured_recommendation(
            medical_advice,
            departments,
            self_care_suggestions,
            risk_level
        )

        return {
            "text": formatted_text,
            "structured": structured_recommendation
        }
class MedicalConsultationPipeline:
    """Run a patient description through extraction, risk and recommendation.

    The three stages share one device. The recommendation stage receives a
    comma-separated list of the extracted symptom texts, or the raw input when
    nothing was extracted.
    """

    def __init__(self,
                 symptom_model="dmis-lab/biobert-v1.1",
                 risk_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
                 recommendation_model="t5-small",
                 device=None):
        """Build all three stage models on ``device`` (default: cuda if available, else cpu)."""
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Initializing Medical Consultation Pipeline on {self.device}...")

        shared_device = self.device
        self.symptom_extractor = SymptomExtractor(model_name=symptom_model, device=shared_device)
        self.risk_classifier = RiskClassifier(model_name=risk_model, device=shared_device)
        self.recommendation_generator = RecommendationGenerator(model_path=recommendation_model, device=shared_device)

        print("Medical Consultation Pipeline initialized successfully.")

    def process(self, text: str) -> Dict[str, Any]:
        """Analyse ``text`` and return every stage's output plus the input.

        Returns:
            Dict with ``extraction``, ``risk``, ``recommendation`` (formatted
            text), ``structured_recommendation`` and ``input_text``.
        """
        extraction = self.symptom_extractor.extract_symptoms(text)
        risk = self.risk_classifier.classify_risk(text)

        # Fall back to the raw description when no symptom span was found.
        found_texts = [item["text"] for item in extraction["symptoms"]]
        summary = ", ".join(found_texts) or text

        generated = self.recommendation_generator.generate_recommendation(
            symptoms=summary,
            risk_level=risk["risk_level"]
        )

        return {
            "extraction": extraction,
            "risk": risk,
            "recommendation": generated["text"],
            "structured_recommendation": generated["structured"],
            "input_text": text
        }
# Example usage
if __name__ == "__main__":
    # Manual smoke test. It runs only when this file is executed directly,
    # not when the module is imported (e.g. by the Streamlit app).
    # The variable is named `consultation` rather than `pipeline` so the
    # module-level `transformers.pipeline` import is not rebound.
    consultation = MedicalConsultationPipeline()
    sample_text = "I've been experiencing severe headaches and dizziness for about 2 weeks. Sometimes I also feel nauseous."
    result = consultation.process(sample_text)
    print("Extracted symptoms:", [s["text"] for s in result["extraction"]["symptoms"]])
    print("Duration info:", [d["text"] for d in result["extraction"]["duration"]])
    print("Risk level:", result["risk"]["risk_level"], f"(Confidence: {result['risk']['confidence']:.2f})")
    print("Recommendation:", result["recommendation"])