Spaces:
Sleeping
Sleeping
import json
import re
import time

import numpy as np
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSequenceClassification
app = FastAPI(title="EdgeMed Clinical BERT API")

# Fully open CORS so any browser front-end may call this API.
# allow_credentials is not passed, so Starlette's default for it applies.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# -- Label maps (from your notebook) ------------------------------------------
# Model class index -> ESI label, and its inverse.
id2label = {
    0: "ESI_1",
    1: "ESI_2",
    2: "ESI_3",
    3: "ESI_4",
    4: "ESI_5",
}
label2id = dict(zip(id2label.values(), id2label.keys()))

# Per-level SLA in minutes and the human-readable level name.
ESI_SLA = dict(ESI_1=2, ESI_2=10, ESI_3=30, ESI_4=60, ESI_5=120)
ESI_LABEL = dict(
    ESI_1="Resuscitation",
    ESI_2="Emergent",
    ESI_3="Urgent",
    ESI_4="Less Urgent",
    ESI_5="Non-Urgent",
)
# -- CAG keyword lookup (exact from your notebook) ----------------------------
# Keyword lists per ESI level; cag_classify reports the most urgent hit.
CAG_RULES = {
    "ESI_1": ["cardiac arrest", "not breathing", "no pulse", "unresponsive",
              "unconscious", "active seizure", "anaphylaxis", "major trauma",
              "respiratory arrest", "hemorrhagic shock", "arrest", "cpr",
              "resus", "apnea", "shock", "code"],
    "ESI_2": ["chest pain", "acute stroke", "stroke", "altered mental status",
              "severe pain", "overdose", "sepsis", "hypertensive emergency",
              "myocardial infarction", "difficulty breathing",
              "shortness of breath", "loss of consciousness", "syncope",
              "fainting", "high fever", "sob", "dyspnea", "loc", "seizure",
              "convulsion", "palpitation", "hypotension", "ams", "cp"],
    "ESI_3": ["moderate pain", "fever", "fracture", "vomiting", "dizziness",
              "weakness", "wound", "laceration", "burn", "abdominal pain",
              "back pain", "headache", "swelling", "infection", "urinary",
              "bleeding", "trauma", "injury", "pain"],
    "ESI_4": ["mild pain", "rash", "sore throat", "ear pain", "eye pain",
              "minor", "sprain", "cough", "cold", "mild", "ocular"],
    "ESI_5": ["prescription refill", "routine", "paperwork",
              "immunization", "administrative", "certificate"],
}
# Most urgent first: a lower index means a more urgent level.
PRIORITY = ["ESI_1", "ESI_2", "ESI_3", "ESI_4", "ESI_5"]

# Flat keyword -> ESI lookup.  Its iteration order (ESI_1 keywords first)
# decides which keyword is reported when several hits share the top level.
# A keyword listed under two levels would keep the later level.
CAG_LOOKUP = {kw: esi for esi, keywords in CAG_RULES.items() for kw in keywords}

# Keywords this short are exactly the abbreviations in CAG_RULES
# ("cpr", "sob", "loc", "ams", "cp").  As raw substrings they fire inside
# ordinary words ("blocked" -> "loc", "exams" -> "ams", "sober" -> "sob"),
# and since they are ESI_1/ESI_2 keywords a false hit makes triage skip BERT.
# They must therefore match as whole words.
_CAG_WHOLE_WORD_MAX_LEN = 3


def _cag_keyword_hit(kw: str, text: str) -> bool:
    """Return True if keyword *kw* occurs in the already lower-cased *text*."""
    if len(kw) <= _CAG_WHOLE_WORD_MAX_LEN:
        return re.search(r"\b" + re.escape(kw) + r"\b", text) is not None
    # Longer phrases keep plain substring matching so inflected forms
    # ("seizures", "palpitations") still hit.
    return kw in text


def cag_classify(text: str):
    """Return the most urgent ``(esi_level, keyword)`` CAG hit in *text*.

    Matching is case-insensitive. Among hits of the same level, the keyword
    that comes first in CAG_LOOKUP order is reported. Returns
    ``(None, None)`` when no keyword matches.

    NOTE(review): "mild pain", "ear pain" and "eye pain" (ESI_4) always
    co-occur with "pain" (ESI_3), which outranks them, so they can never be
    the reported keyword.  "code" (ESI_1) is longer than the whole-word
    threshold and still substring-matches words such as "codeine".
    """
    t = text.lower()
    matched_esi, matched_kw = None, None
    for kw, esi in CAG_LOOKUP.items():
        if not _cag_keyword_hit(kw, t):
            continue
        if matched_esi is None or PRIORITY.index(esi) < PRIORITY.index(matched_esi):
            matched_esi, matched_kw = esi, kw
    return matched_esi, matched_kw
# -- Keyword -> specialty map (from your notebook) ----------------------------
# Keys are word stems ("abdom", "urin", "psych"); the first key that matches
# (in insertion order) decides the specialty.
KEYWORD_SPECIALTY = {
    "cardiac": "Cardiology", "chest": "Cardiology",
    "heart": "Cardiology", "neuro": "Neurology",
    "stroke": "Neurology", "seizure": "Neurology",
    "head": "Neurology", "fracture": "Orthopedic",
    "bone": "Orthopedic", "joint": "Orthopedic",
    "abdom": "General Surgery", "bowel": "Gastroenterology",
    "liver": "Gastroenterology", "breath": "Pulmonology",
    "lung": "Pulmonology", "psych": "Psychiatry",
    "mental": "Psychiatry", "eye": "Ophthalmology",
    "ocular": "Ophthalmology", "ear": "ENT",
    "throat": "ENT", "urin": "Urology",
    "kidney": "Nephrology", "renal": "Nephrology",
    "burn": "General Surgery", "wound": "General Surgery",
}
# Fallback specialty per ESI level when no keyword matches.
ESI_DEFAULT_SPECIALTY = {
    "ESI_1": "Emergency Medicine", "ESI_2": "Emergency Medicine",
    "ESI_3": "General Surgery", "ESI_4": "General Surgery",
    "ESI_5": "General Surgery",
}


def detect_specialty(symptom_text: str, esi_level: str) -> str:
    """Pick a specialty for *symptom_text*, falling back on the ESI level.

    Each KEYWORD_SPECIALTY stem must match at the START of a word
    (case-insensitive). A bare substring test let "ear" fire on "year" and
    "near", "liver" on "delivery", and "renal" on "adrenaline".

    NOTE(review): stems still prefix-match, so "ear" also hits "early" and
    "earlier". In-word compounds such as "pinkeye" or "lightheaded" do not
    match "eye"/"head".

    Returns the first matching specialty, else
    ``ESI_DEFAULT_SPECIALTY[esi_level]``, else ``"Emergency Medicine"``.
    """
    t = symptom_text.lower()
    for kw, spec in KEYWORD_SPECIALTY.items():
        if re.search(r"\b" + re.escape(kw), t):
            return spec
    return ESI_DEFAULT_SPECIALTY.get(esi_level, "Emergency Medicine")
# -- Load BERT model ----------------------------------------------------------
# Runs once, at import time.
_MODEL_ID = "Mahdiya/edgemed-clinical-bert"
print(f"Loading {_MODEL_ID} ...")
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID)

device = "cpu"  # CPU Basic Space -- no GPU available
model.to(device)
model.eval()  # inference only: disables dropout
print(f"β Model loaded on {device}")
def bert_classify(text: str):
    """Run the fine-tuned BERT classifier on one complaint string.

    The text is cut to its first 400 characters, then tokenized with
    truncation and padding to exactly 128 tokens (a batch of one).

    Returns:
        ``(pred_esi, conf, latency_ms, all_probs)``, where:

        - ``pred_esi`` is the argmax label from ``id2label``.
        - ``conf`` is its softmax probability, rounded to 4 dp.
        - ``latency_ms`` times the forward pass only; tokenization is excluded.
        - ``all_probs`` maps each ``id2label`` name to its probability.
    """
    enc = tokenizer(
        text[:400], return_tensors="pt",
        max_length=128, truncation=True, padding="max_length",
    ).to(device)

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with system clock adjustments and yield bogus (even negative) latencies.
    t0 = time.perf_counter()
    with torch.no_grad():
        logits = model(**enc).logits
    latency_ms = round((time.perf_counter() - t0) * 1000, 1)

    probs = torch.softmax(logits, dim=-1)[0].cpu().tolist()  # row 0: the single input
    pred_id = int(torch.argmax(logits, dim=-1).item())
    pred_esi = id2label[pred_id]
    conf = round(probs[pred_id], 4)
    all_probs = {id2label[i]: round(p, 4) for i, p in enumerate(probs)}
    return pred_esi, conf, latency_ms, all_probs
# -- Hospital data (200 hospitals, 5 zones -- from your notebook seed=42) -----
# A private RandomState seeded with 42 produces exactly the same draws as the
# notebook's global np.random.seed(42). It does so without reseeding numpy's
# process-wide RNG as an import side effect.
_HOSPITAL_RNG = np.random.RandomState(42)

SPECIALTIES_ALL = [
    "Cardiology", "Neurology", "Orthopedic", "General Surgery",
    "Emergency Medicine", "Gastroenterology", "Pulmonology",
    "Nephrology", "Psychiatry", "Ophthalmology", "ENT", "Urology",
    "Oncology", "Dermatology", "Pediatrics", "Gynecology",
    "Radiology", "Anesthesiology", "Hematology", "Rheumatology",
]
ZONES = ["Zone-A", "Zone-B", "Zone-C", "Zone-D", "Zone-E"]

HOSPITALS = []
for i in range(200):
    zone = ZONES[i // 40]  # 40 consecutive hospitals per zone
    n_specs = int(_HOSPITAL_RNG.randint(3, 7))  # 3..6 specialties
    specs = list(_HOSPITAL_RNG.choice(SPECIALTIES_ALL, n_specs, replace=False))
    # The value expressions below draw from the RNG in this exact order;
    # keep it, or the generated data will differ from the notebook.
    HOSPITALS.append({
        "hospital_id": f"H{str(i).zfill(3)}",
        "name": f"{zone.replace('Zone-','').strip()} Medical Center {i%40+1}",
        "zone": zone,
        "specialties": specs,
        "response_time": round(float(_HOSPITAL_RNG.uniform(1, 30)), 1),
        "quality_score": round(float(_HOSPITAL_RNG.uniform(0.5, 1.0)), 2),
        "current_load": round(float(_HOSPITAL_RNG.uniform(0.1, 0.9)), 2),
        "availability": bool(_HOSPITAL_RNG.random_sample() > 0.2),
    })
def routing_score(h: dict, alpha: float) -> float:
    """Score hospital *h* for routing (exact formula from your notebook).

    The score is ``(alpha * speed + (1 - alpha) * quality) * (1 - 0.3 * load)``.
    Here ``speed = 1 - response_time / 30``, ``quality = quality_score`` and
    ``load = current_load``. The result is rounded to 4 decimal places.
    """
    speed_term = alpha * (1.0 - (h["response_time"] / 30.0))
    quality_term = (1 - alpha) * h["quality_score"]
    load_factor = 1 - h["current_load"] * 0.3
    return round((speed_term + quality_term) * load_factor, 4)
def get_top_hospitals(specialty: str, zone: str, alpha: float,
                      esi: str, top_n: int = 10) -> list:
    """Rank candidate hospitals for a patient.

    Hospitals that are unavailable, or whose ``current_load`` exceeds 0.85,
    are never returned. The rest depends on acuity.

    * ESI_1 / ESI_2 (emergency): every remaining hospital qualifies,
      whatever its specialties. Each is scored with alpha forced to 1.0.
      The list is ordered by raw ``response_time`` only; the score is
      reported but not used for ordering, and zone is ignored.
    * Any other ESI: only hospitals with a specialty containing *specialty*
      qualify (case-insensitive substring). They are scored with *alpha*
      and ordered in-zone first, then by descending score.

    Each result is a shallow copy of the hospital record. It adds ``score``,
    ``zone_match``, ``spec_match`` and ``cross_zone``. At most *top_n*
    results are returned.

    NOTE(review): an empty *specialty* matches every hospital, because ""
    is a substring of any string.
    """
    emergency = esi in ("ESI_1", "ESI_2")
    candidates = []
    for hosp in HOSPITALS:
        if not hosp["availability"] or hosp["current_load"] > 0.85:
            continue
        has_spec = any(specialty.lower() in s.lower() for s in hosp["specialties"])
        in_zone = hosp["zone"] == zone
        if emergency:
            weight = 1.0  # emergencies score on pure speed
        elif has_spec:
            weight = alpha
        else:
            continue  # routine cases need a specialty match
        candidates.append({
            **hosp,
            "score": routing_score(hosp, weight),
            "zone_match": in_zone,
            "spec_match": has_spec,
            "cross_zone": not in_zone,
        })

    if emergency:
        # Fastest response first, regardless of zone.
        candidates.sort(key=lambda c: c["response_time"])
    else:
        # In-zone hospitals first (False sorts before True), then best score.
        candidates.sort(key=lambda c: (not c["zone_match"], -c["score"]))
    return candidates[:top_n]
# -- Request / Response models ------------------------------------------------
# alpha blends speed and quality in routing_score as
# alpha*speed + (1-alpha)*quality. That is only a weighted average for
# 0 <= alpha <= 1; outside that range the weights flip sign and invert the
# ranking. Out-of-range values are therefore rejected at validation time.


class TriageRequest(BaseModel):
    """Request body for the triage step.

    NOTE(review): triage() reads only symptom_text; zone and alpha are
    accepted but unused there.
    """

    symptom_text: str
    zone: str
    alpha: float = Field(0.5, ge=0.0, le=1.0)


class RouteRequest(BaseModel):
    """Request body for the routing step."""

    symptom_text: str
    zone: str
    alpha: float = Field(..., ge=0.0, le=1.0)  # required, as before
    esi_level: str  # already determined (from triage step)
    specialty: str  # already determined
# -- Endpoints ----------------------------------------------------------------
# NOTE(review): no @app.get/@app.post decorator is visible on this function,
# so FastAPI would not route to it as shown. Confirm how it is registered.
def root():
    """Return a status payload: service banner, model id and device."""
    payload = dict(
        status="EdgeMed API running",
        model="Mahdiya/edgemed-clinical-bert",
        device=device,
    )
    return payload
def _triage_payload(esi: str, confidence: float, method: str, cag_kw,
                    specialty: str, bert_probs, t_start: float) -> dict:
    """Build the triage response for level *esi*.

    ``latency_ms`` is the time elapsed since *t_start*, a
    ``time.perf_counter()`` reading, in milliseconds rounded to 1 dp.
    """
    return {
        "esi_level": esi,
        "esi_label": ESI_LABEL[esi],
        "sla_minutes": ESI_SLA[esi],
        "confidence": confidence,
        "method": method,
        "cag_keyword": cag_kw,
        "specialty": specialty,
        "bert_probs": bert_probs,
        "latency_ms": round((time.perf_counter() - t_start) * 1000, 1),
    }


# NOTE(review): no @app.post/@app.get decorator is visible on this function,
# so FastAPI would not expose it as shown. Confirm how it is registered.
def triage(req: TriageRequest):
    """Classify a free-text complaint into an ESI level.

    Pipeline:
      1. CAG keyword check. An ESI_1/ESI_2 hit skips BERT entirely and is
         returned with confidence 1.0 and method "CAG_BYPASS".
      2. Otherwise BERT always runs and may predict any level. If CAG
         matched a lower-acuity keyword (ESI_3-5) that is more urgent than
         BERT's prediction, the CAG level wins (method "CAG+BERT").
         Otherwise BERT's level is used (method "BERT").

    Only ``req.symptom_text`` is read.

    Returns a dict with keys esi_level, esi_label, sla_minutes, confidence,
    method, cag_keyword, specialty, bert_probs (None on bypass) and
    latency_ms (whole pipeline).
    """
    t_start = time.perf_counter()  # monotonic, unlike time.time()

    # Step 1: CAG keyword check.
    cag_esi, cag_kw = cag_classify(req.symptom_text)
    if cag_esi in ("ESI_1", "ESI_2"):
        # Bypass BERT -- critical keyword found.
        specialty = detect_specialty(req.symptom_text, cag_esi)
        return _triage_payload(cag_esi, 1.0, "CAG_BYPASS", cag_kw,
                               specialty, None, t_start)

    # Step 2: BERT inference.
    bert_esi, conf, _bert_latency, all_probs = bert_classify(req.symptom_text)

    # A lower-priority CAG hint (ESI 3-5) wins only if it is more urgent.
    final_esi, method = bert_esi, "BERT"
    if cag_esi and PRIORITY.index(cag_esi) < PRIORITY.index(bert_esi):
        final_esi, method = cag_esi, "CAG+BERT"

    specialty = detect_specialty(req.symptom_text, final_esi)
    # NOTE(review): confidence is BERT's probability for its own prediction
    # (bert_esi), even when a CAG hint overrides it with a different level.
    return _triage_payload(final_esi, conf, method, cag_kw,
                           specialty, all_probs, t_start)
# NOTE(review): no @app.post/@app.get decorator is visible on this function,
# so FastAPI would not expose it as shown. Confirm how it is registered.
def route(req: RouteRequest):
    """KAG routing step: rank hospitals for an already-triaged patient.

    Delegates to get_top_hospitals with the request's specialty, zone,
    alpha and ESI level, keeping at most 10 hospitals. Echoes the inputs
    back alongside the ranked list and its length.
    """
    ranked = get_top_hospitals(
        esi=req.esi_level,
        specialty=req.specialty,
        zone=req.zone,
        alpha=req.alpha,
        top_n=10,
    )
    return dict(
        zone=req.zone,
        specialty=req.specialty,
        esi_level=req.esi_level,
        alpha=req.alpha,
        hospitals=ranked,
        total=len(ranked),
    )
# NOTE(review): no @app.get decorator is visible on this function, so FastAPI
# would not expose it as shown. Confirm how it is registered.
def zones():
    """Summarise hospital counts per zone.

    Returns ``{"zones": ZONES, "counts": {zone: {"total": n, "available": m}}}``.
    Here ``total`` counts every hospital in HOSPITALS for that zone, and
    ``available`` counts those with ``availability`` True. The total is
    counted from the data rather than assumed, so it stays correct if the
    number of hospitals per zone changes.
    """
    counts = {}
    for z in ZONES:
        members = [h for h in HOSPITALS if h["zone"] == z]
        counts[z] = {
            "total": len(members),
            "available": sum(1 for h in members if h["availability"]),
        }
    return {"zones": ZONES, "counts": counts}