| """ |
| AMR-Guard — Gradio Interface (ZeroGPU compatible) |
| Infection Lifecycle Orchestrator · Multi-Agent Clinical Decision Support |
| """ |
|
|
| import json |
| import logging |
| import os |
| import subprocess |
| import sys |
| import traceback |
| from io import BytesIO |
| from pathlib import Path |
|
|
| PROJECT_ROOT = Path(__file__).parent |
| sys.path.insert(0, str(PROJECT_ROOT)) |
|
|
| |
| |
| logging.basicConfig( |
| level=logging.INFO, |
| stream=sys.stdout, |
| format="%(asctime)s %(levelname)s %(name)s: %(message)s", |
| force=True, |
| ) |
| logger = logging.getLogger(__name__) |
|
|
| |
| |
| |
| |
| try: |
| from huggingface_hub import HfFolder as _check |
| except ImportError: |
| import huggingface_hub as _hfh |
|
|
| class _HfFolder: |
| @staticmethod |
| def get_token(): |
| return os.environ.get("HF_TOKEN") or _hfh.get_token() |
|
|
| @staticmethod |
| def save_token(token: str) -> None: |
| pass |
|
|
| @staticmethod |
| def delete_token() -> None: |
| pass |
|
|
| _hfh.HfFolder = _HfFolder |
|
|
| |
| _DB_PATH = PROJECT_ROOT / os.getenv("MEDIC_DATA_DIR", "data") / "amr_guard.db" |
| if os.environ.get("SPACE_ID") and not _DB_PATH.exists(): |
| subprocess.run([sys.executable, str(PROJECT_ROOT / "setup_demo.py")], check=False) |
|
|
| import gradio as gr |
| import pandas as pd |
|
|
| |
| |
| |
| |
| try: |
| import gradio.utils as _gr_utils |
| _orig_get_type = getattr(_gr_utils, "get_type", None) |
| if _orig_get_type: |
| def _safe_get_type(schema, *a, **kw): |
| if not isinstance(schema, dict): |
| return "other" |
| return _orig_get_type(schema, *a, **kw) |
| _gr_utils.get_type = _safe_get_type |
| except Exception: |
| pass |
| try: |
| import gradio.route_utils as _gr_ru |
| for _fn_name in ("get_type", "_json_schema_to_python_type", "json_schema_to_python_type"): |
| _fn = getattr(_gr_ru, _fn_name, None) |
| if _fn: |
| def _safe_fn(schema, *a, _f=_fn, **kw): |
| if not isinstance(schema, dict): |
| return "other" |
| return _f(schema, *a, **kw) |
| setattr(_gr_ru, _fn_name, _safe_fn) |
| except Exception: |
| pass |
|
|
| from src.config import get_settings |
| from src.form_config import CREATININE_PROMINENT_SITES, SITE_SPECIFIC_FIELDS, SUSPECTED_SOURCE_OPTIONS |
| from src.loader import run_inference |
|
|
| |
| |
| |
| |
| |
| |
| |
| if os.environ.get("SPACE_ID"): |
| try: |
| import spaces as _spaces_ui |
|
|
| @_spaces_ui.GPU(duration=200) |
| def _run_pipeline_gpu(patient_data: dict, labs_raw_text): |
| from src.graph import run_pipeline |
| return run_pipeline(patient_data, labs_raw_text) |
| except ImportError: |
| def _run_pipeline_gpu(patient_data: dict, labs_raw_text): |
| from src.graph import run_pipeline |
| return run_pipeline(patient_data, labs_raw_text) |
| else: |
| def _run_pipeline_gpu(patient_data: dict, labs_raw_text): |
| from src.graph import run_pipeline |
| return run_pipeline(patient_data, labs_raw_text) |
|
|
| from src.tools import ( |
| calculate_mic_trend, |
| get_empirical_therapy_guidance, |
| get_most_effective_antibiotics, |
| interpret_mic_value, |
| screen_antibiotic_safety, |
| search_clinical_guidelines, |
| ) |
|
|
| |
|
|
| CSS = """ |
| @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap'); |
| body, .gradio-container { font-family: 'Inter', sans-serif !important; } |
| |
| /* ── Banner ── */ |
| .med-banner { |
| background: linear-gradient(135deg, #020d1f 0%, #0b2545 55%, #102f62 100%); |
| padding: 24px 32px; border-radius: 14px; margin-bottom: 22px; |
| border: 1px solid #1e4a80; |
| box-shadow: 0 0 48px rgba(26,74,138,0.45), inset 0 1px 0 rgba(158,196,240,0.12); |
| } |
| .med-banner h1 { |
| color: #ffffff; font-size: 1.95rem; font-weight: 700; margin: 0; |
| text-shadow: 0 0 24px rgba(96,196,255,0.45); |
| } |
| .med-banner p { color: #7eb8e8; font-size: 0.95rem; margin: 5px 0 0; } |
| |
| /* ── Section titles ── */ |
| .section-title { |
| font-size: 0.8rem; font-weight: 700; color: #60b4ff; |
| border-bottom: 1px solid #1e3f72; padding-bottom: 6px; margin: 18px 0 13px; |
| text-transform: uppercase; letter-spacing: 0.1em; |
| } |
| |
| /* ── Stat cards ── */ |
| .stat-cards { |
| display: grid; grid-template-columns: repeat(4, 1fr); gap: 16px; margin-bottom: 22px; |
| } |
| .stat-card { |
| background: linear-gradient(160deg, #0b1e3d 0%, #0e2a56 100%); |
| border: 1px solid #1e4a80; border-top: 3px solid #3b82f6; |
| border-radius: 11px; padding: 18px 20px; text-align: center; |
| box-shadow: 0 4px 18px rgba(0,0,0,0.35); |
| } |
| .stat-card .label { |
| color: #7eaadb; font-size: 0.78rem; font-weight: 600; |
| text-transform: uppercase; letter-spacing: 0.05em; |
| } |
| .stat-card .value { color: #60c8ff; font-size: 1.65rem; font-weight: 700; margin-top: 5px; } |
| .stat-card .sub { color: #a8cce8; font-size: 0.75rem; margin-top: 3px; } |
| |
| /* ── Agent steps ── */ |
| .agent-step { |
| background: linear-gradient(135deg, #091a36 0%, #0d2450 100%); |
| border: 1px solid #1e4278; border-left: 4px solid #3b82f6; |
| border-radius: 8px; padding: 14px 16px; margin-bottom: 10px; |
| } |
| .agent-step .num { color: #60b4ff; font-weight: 700; font-size: 0.82rem; letter-spacing: 0.04em; } |
| .agent-step .name { color: #dceeff; font-weight: 600; } |
| .agent-step .desc { color: #8ab4d8; font-size: 0.85rem; margin-top: 4px; } |
| |
| /* ── Status badges — dark backgrounds, high-contrast text ── */ |
| .badge-high { |
| background: #1e0707; border-left: 4px solid #dc2626; color: #fca5a5; |
| padding: 10px 14px; border-radius: 7px; margin-bottom: 6px; |
| } |
| .badge-moderate { |
| background: #1c1200; border-left: 4px solid #d97706; color: #fcd34d; |
| padding: 10px 14px; border-radius: 7px; margin-bottom: 6px; |
| } |
| .badge-low { |
| background: #021a0e; border-left: 4px solid #16a34a; color: #86efac; |
| padding: 10px 14px; border-radius: 7px; margin-bottom: 6px; |
| } |
| .badge-info { |
| background: #071428; border-left: 4px solid #2563eb; color: #93c5fd; |
| padding: 10px 14px; border-radius: 7px; margin-bottom: 6px; |
| } |
| |
| /* ── Prescription card ── */ |
| .rx-card { |
| background: linear-gradient(145deg, #081730 0%, #0c2248 100%); |
| border: 1px solid #1e4a80; border-radius: 12px; |
| padding: 24px 26px; font-size: 0.9rem; line-height: 1.75; color: #cce3ff; |
| box-shadow: 0 6px 28px rgba(0,0,0,0.45), 0 0 0 1px rgba(59,130,246,0.18); |
| } |
| .rx-card .rx-symbol { |
| font-size: 2.2rem; color: #60c8ff; font-weight: 700; |
| text-shadow: 0 0 14px rgba(96,200,255,0.55); |
| } |
| .rx-card .rx-drug { font-size: 1.25rem; font-weight: 700; color: #ffffff; } |
| .rx-card strong { color: #a8d4ff; } |
| .rx-card ul { color: #cce3ff; } |
| |
| /* ── Badge child elements inherit color ── */ |
| .badge-high strong, .badge-high em, .badge-high span { color: inherit; } |
| .badge-moderate strong,.badge-moderate em,.badge-moderate span { color: inherit; } |
| .badge-low strong, .badge-low em, .badge-low span { color: inherit; } |
| .badge-info strong, .badge-info em, .badge-info span { color: inherit; } |
| |
| /* ── Disclaimer ── */ |
| .disclaimer { |
| background: #150e00; border: 1px solid #78450e; border-radius: 8px; |
| padding: 12px 16px; font-size: 0.78rem; color: #fbbf24; margin-top: 20px; |
| } |
| """ |
|
|
| BANNER_HTML = """ |
| <div class="med-banner"> |
| <div> |
| <h1>⚕ AMR-Guard</h1> |
| <p>Infection Lifecycle Orchestrator · Multi-Agent Clinical Decision Support</p> |
| </div> |
| </div> |
| """ |
|
|
| INFECTION_SITES = ["urinary", "respiratory", "bloodstream", "skin", "intra-abdominal", "CNS", "other"] |
|
|
|
|
| |
|
|
| def _parse_notes(raw): |
| if not raw or raw in ("No lab data provided", "No MIC data available for trend analysis", ""): |
| return None |
| if isinstance(raw, (dict, list)): |
| return raw |
| try: |
| return json.loads(raw) |
| except Exception: |
| return None |
|
|
|
|
| def _build_rec_html(result: dict) -> str: |
| rec = result.get("recommendation") or {} |
| if not rec: |
| return '<div class="badge-info">No recommendation generated.</div>' |
| primary = rec.get("primary_antibiotic", "—") |
| dose = rec.get("dose", "—") |
| route = rec.get("route", "—") |
| freq = rec.get("frequency", "—") |
| duration = rec.get("duration", "—") |
| alt = rec.get("backup_antibiotic", "") |
| rationale = rec.get("rationale", "") |
| refs = rec.get("references", []) |
| alt_html = f"<br><strong>Alternative:</strong> {alt}" if alt else "" |
| rat_html = f"<br><br><strong>Clinical rationale</strong><br>{rationale}" if rationale else "" |
| ref_html = "" |
| if refs: |
| items = "".join(f"<li>{r}</li>" for r in refs) |
| ref_html = f"<br><strong>References</strong><ul style='margin:4px 0 0 16px'>{items}</ul>" |
| return f""" |
| <div class="rx-card"> |
| <div class="rx-symbol">℞</div> |
| <div class="rx-drug">{primary}</div><br> |
| <strong>Dose:</strong> {dose} · |
| <strong>Route:</strong> {route} · |
| <strong>Frequency:</strong> {freq} · |
| <strong>Duration:</strong> {duration} |
| {alt_html}{rat_html}{ref_html} |
| </div>""" |
|
|
|
|
| def _build_intake_html(result: dict) -> str: |
| intake = _parse_notes(result.get("intake_notes", "")) |
| crcl = result.get("creatinine_clearance_ml_min") |
| html = "" |
| if isinstance(intake, dict): |
| v = crcl or intake.get("creatinine_clearance_ml_min", 0) |
| sev = intake.get("infection_severity", "") |
| pathway = intake.get("recommended_stage", "") |
| cells = "" |
| if v: |
| cells += f"<td style='padding:8px 16px 8px 0'><strong>CrCl</strong><br>{float(v):.1f} mL/min</td>" |
| if sev: |
| cells += f"<td style='padding:8px 16px'><strong>Severity</strong><br>{sev.capitalize()}</td>" |
| if pathway: |
| cells += f"<td style='padding:8px 16px'><strong>Pathway</strong><br>{pathway.capitalize()}</td>" |
| if cells: |
| html += f"<table style='margin-bottom:12px'><tr>{cells}</tr></table>" |
| if intake.get("patient_summary"): |
| html += f'<div class="badge-info">{intake["patient_summary"]}</div>' |
| if intake.get("renal_dose_adjustment_needed"): |
| html += '<div class="badge-moderate" style="margin-top:8px">⚠ Renal dose adjustment required</div>' |
| if intake.get("identified_risk_factors"): |
| items = "".join(f"<li>{rf}</li>" for rf in intake["identified_risk_factors"]) |
| html += f"<br><strong>Identified risk factors</strong><ul style='margin:4px 0 0 16px'>{items}</ul>" |
| elif crcl: |
| html = f"<strong>CrCl:</strong> {float(crcl):.1f} mL/min" |
| else: |
| html = '<div class="badge-info">Intake summary not available.</div>' |
| return html |
|
|
|
|
| def _build_lab_html_and_df(result: dict) -> tuple[str, pd.DataFrame]: |
| vision = _parse_notes(result.get("vision_notes", "")) |
| trend = _parse_notes(result.get("trend_notes", "")) |
| html = "" |
| df = pd.DataFrame() |
|
|
| if vision is None: |
| html += '<div class="badge-info">No lab data processed. Provide lab results to activate the targeted pathway.</div>' |
| else: |
| v = vision if isinstance(vision, dict) else {} |
| if v.get("specimen_type"): |
| html += f"<strong>Specimen:</strong> {v['specimen_type'].capitalize()}<br>" |
| if v.get("extraction_confidence") is not None: |
| conf = float(v["extraction_confidence"]) |
| color = "#86efac" if conf >= 0.85 else "#fcd34d" if conf >= 0.6 else "#fca5a5" |
| html += (f'<div class="badge-info">Extraction confidence: ' |
| f'<span style="color:{color};font-weight:700">{conf:.0%}</span></div>') |
| orgs = v.get("identified_organisms", []) |
| if orgs: |
| items = "".join( |
| f"<li><strong>{o.get('organism_name','?')}</strong>" |
| + (f" — {o.get('significance','')}" if o.get("significance") else "") |
| + "</li>" |
| for o in orgs |
| ) |
| html += f"<br><strong>Identified organisms</strong><ul style='margin:4px 0 0 16px'>{items}</ul>" |
| sus = v.get("susceptibility_results", []) |
| if sus: |
| rows = [ |
| { |
| "Organism": e.get("organism", ""), |
| "Antibiotic": e.get("antibiotic", ""), |
| "MIC (mg/L)": str(e.get("mic_value", "")), |
| "Result": e.get("interpretation", ""), |
| } |
| for e in sus |
| ] |
| df = pd.DataFrame(rows) |
|
|
| if trend: |
| html += "<hr><strong>MIC Trend Analysis</strong><br>" |
| items = trend if isinstance(trend, list) else [trend] |
| for item in items: |
| if not isinstance(item, dict): |
| html += f"<p>{item}</p>" |
| continue |
| risk = item.get("risk_level", "UNKNOWN").upper() |
| css = {"HIGH": "badge-high", "MODERATE": "badge-moderate"}.get(risk, "badge-low") |
| icon = {"HIGH": "🚨", "MODERATE": "⚠"}.get(risk, "✓") |
| org = item.get("organism", "") |
| ab = item.get("antibiotic", "") |
| label = f"{org} / {ab} — " if (org or ab) else "" |
| html += (f'<div class="{css}">{icon} <strong>{label}{risk}</strong><br>' |
| f'<span style="font-size:0.88rem">{item.get("recommendation","")}</span></div>') |
| return html, df |
|
|
|
|
| def _build_safety_html(result: dict) -> str: |
| warnings = result.get("safety_warnings", []) |
| errors = result.get("errors", []) |
| html = "".join(f'<div class="badge-high">⚠ {w}</div>' for w in warnings) |
| if not warnings: |
| html = '<div class="badge-low">✓ No safety concerns identified.</div>' |
| html += "".join(f'<div class="badge-high" style="margin-top:6px">Error: {e}</div>' for e in errors) |
| return html |
|
|
|
|
| def _demo_result(patient_data: dict, has_labs: bool) -> dict: |
| result = { |
| "stage": "targeted" if has_labs else "empirical", |
| "creatinine_clearance_ml_min": 58.3, |
| "intake_notes": json.dumps({ |
| "patient_summary": ( |
| f"{patient_data.get('age_years')}-year-old {patient_data.get('sex')} " |
| f"· {patient_data.get('suspected_source', 'infection')}" |
| ), |
| "creatinine_clearance_ml_min": 58.3, |
| "renal_dose_adjustment_needed": True, |
| "identified_risk_factors": patient_data.get("comorbidities", []), |
| "infection_severity": "moderate", |
| "recommended_stage": "targeted" if has_labs else "empirical", |
| }), |
| "recommendation": { |
| "primary_antibiotic": "Ciprofloxacin", |
| "dose": "500 mg", |
| "route": "Oral", |
| "frequency": "Every 12 hours", |
| "duration": "7 days", |
| "backup_antibiotic": "Nitrofurantoin 100 mg MR BD × 5 days", |
| "rationale": ( |
| "Community-acquired UTI with moderate renal impairment (CrCl 58 mL/min). " |
| "Ciprofloxacin provides broad Gram-negative coverage. " |
| "No dose adjustment required above CrCl 30 mL/min." |
| ), |
| "references": ["IDSA UTI Guidelines 2024", "EUCAST Breakpoint Tables v16.0"], |
| }, |
| "safety_warnings": [], |
| "errors": [], |
| } |
| if has_labs: |
| result["vision_notes"] = json.dumps({ |
| "specimen_type": "urine", |
| "identified_organisms": [{"organism_name": "Escherichia coli", "significance": "pathogen"}], |
| "susceptibility_results": [ |
| {"organism": "E. coli", "antibiotic": "Ciprofloxacin", "mic_value": 0.25, "interpretation": "S"}, |
| {"organism": "E. coli", "antibiotic": "Nitrofurantoin", "mic_value": 16, "interpretation": "S"}, |
| {"organism": "E. coli", "antibiotic": "Ampicillin", "mic_value": ">32", "interpretation": "R"}, |
| ], |
| "extraction_confidence": 0.95, |
| }) |
| result["trend_notes"] = json.dumps([{ |
| "organism": "E. coli", "antibiotic": "Ciprofloxacin", |
| "risk_level": "LOW", "recommendation": "No MIC creep detected.", |
| }]) |
| return result |
|
|
|
|
| |
|
|
| def update_site_ui(site): |
| grp_updates = [gr.update(visible=(s == site)) for s in INFECTION_SITES] |
| src_choices = SUSPECTED_SOURCE_OPTIONS.get(site, []) or ["Other"] |
| prominent = site in CREATININE_PROMINENT_SITES |
| return ( |
| *grp_updates, |
| gr.update(choices=src_choices, value=src_choices[0]), |
| gr.update(visible=prominent), |
| gr.update(visible=not prominent), |
| gr.update(visible=False), |
| ) |
|
|
|
|
| def toggle_optional_creatinine(flag): |
| return gr.update(visible=bool(flag)) |
|
|
|
|
| def toggle_lab_inputs(method): |
| return ( |
| gr.update(visible=(method == "Upload file (PDF / image)")), |
| gr.update(visible=(method == "Paste lab text")), |
| ) |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| def run_pipeline_ui( |
| age, weight, height, sex, |
| creatinine_main, renal_flag, creatinine_optional, |
| infection_site, suspected_source, |
| |
| sf0, sf1, sf2, |
| |
| sf3, sf4, sf5, sf6, |
| |
| sf7, sf8, sf9, sf10, sf11, sf12, sf13, |
| |
| sf14, sf15, sf16, sf17, |
| |
| sf18, sf19, sf20, sf21, |
| |
| sf22, sf23, sf24, sf25, |
| |
| medications, allergies, comorbidities, risk_factors, |
| |
| lab_method, lab_file, lab_paste, |
| progress=gr.Progress(), |
| ): |
| |
| if infection_site in CREATININE_PROMINENT_SITES: |
| creatinine = creatinine_main |
| else: |
| creatinine = creatinine_optional if renal_flag else None |
|
|
| |
| site_vitals: dict = {} |
| if infection_site == "urinary": |
| site_vitals = { |
| "catheter_status": str(sf0 or ""), |
| "urinary_symptoms": ", ".join(sf1) if sf1 else "", |
| "urine_appearance": str(sf2 or ""), |
| } |
| elif infection_site == "respiratory": |
| site_vitals = { |
| "o2_saturation": str(sf3 or ""), |
| "ventilation_status": str(sf4 or ""), |
| "cough_type": str(sf5 or ""), |
| "sputum_character": str(sf6 or ""), |
| } |
| elif infection_site == "bloodstream": |
| site_vitals = { |
| "central_line_present": "Yes" if sf7 else "No", |
| "temperature_c": str(sf8 or ""), |
| "heart_rate_bpm": str(sf9 or ""), |
| "respiratory_rate": str(sf10 or ""), |
| "wbc_count": str(sf11 or ""), |
| "lactate_mmol": str(sf12 or ""), |
| "shock_status": str(sf13 or ""), |
| } |
| elif infection_site == "skin": |
| site_vitals = { |
| "wound_type": str(sf14 or ""), |
| "cellulitis_extent": str(sf15 or ""), |
| "abscess_present": "Yes" if sf16 else "No", |
| "foreign_body": "Yes" if sf17 else "No", |
| } |
| elif infection_site == "intra-abdominal": |
| site_vitals = { |
| "abdominal_pain_location": str(sf18 or ""), |
| "peritonitis_signs": ", ".join(sf19) if sf19 else "", |
| "perforation_suspected": "Yes" if sf20 else "No", |
| "ascites": "Yes" if sf21 else "No", |
| } |
| elif infection_site == "CNS": |
| site_vitals = { |
| "csf_obtained": "Yes" if sf22 else "No", |
| "neuro_symptoms": ", ".join(sf23) if sf23 else "", |
| "recent_neurosurgery": "Yes" if sf24 else "No", |
| "gcs_score": str(sf25 or ""), |
| } |
|
|
| |
| labs_raw_text = None |
| labs_image_bytes = None |
| if lab_method == "Upload file (PDF / image)" and lab_file is not None: |
| file_path = lab_file if isinstance(lab_file, str) else lab_file.name |
| ext = file_path.rsplit(".", 1)[-1].lower() |
| with open(file_path, "rb") as fh: |
| file_bytes = fh.read() |
| if ext == "pdf": |
| try: |
| import pypdf |
| reader = pypdf.PdfReader(BytesIO(file_bytes)) |
| extracted = "\n".join(p.extract_text() or "" for p in reader.pages).strip() |
| if extracted: |
| labs_raw_text = extracted |
| else: |
| labs_image_bytes = file_bytes |
| except Exception: |
| labs_image_bytes = file_bytes |
| else: |
| labs_image_bytes = file_bytes |
| elif lab_method == "Paste lab text" and lab_paste: |
| labs_raw_text = lab_paste.strip() or None |
|
|
| patient_data = { |
| "age_years": float(age or 65), |
| "weight_kg": float(weight or 70), |
| "height_cm": float(height or 170), |
| "sex": sex or "male", |
| "serum_creatinine_mg_dl": float(creatinine) if creatinine else None, |
| "infection_site": infection_site, |
| "suspected_source": suspected_source or f"{infection_site} infection", |
| "medications": [m.strip() for m in (medications or "").split("\n") if m.strip()], |
| "allergies": [a.strip() for a in (allergies or "").split("\n") if a.strip()], |
| "comorbidities": list(comorbidities or []) + list(risk_factors or []), |
| "vitals": site_vitals, |
| "labs_image_bytes": labs_image_bytes, |
| } |
|
|
| has_labs = bool(labs_raw_text or labs_image_bytes) |
| stages = ( |
| ["Intake Historian", "Vision Specialist", "Trend Analyst", "Clinical Pharmacologist"] |
| if has_labs else ["Intake Historian", "Clinical Pharmacologist"] |
| ) |
|
|
| for i, name in enumerate(stages): |
| progress((i + 0.5) / len(stages), desc=f"Running: {name}…") |
|
|
| try: |
| result = _run_pipeline_gpu(patient_data, labs_raw_text) |
| except Exception as e: |
| tb = traceback.format_exc() |
| logger.error("Pipeline failed — falling back to demo result.\n%s", tb) |
| result = _demo_result(patient_data, has_labs) |
| result["errors"].append(f"Pipeline error: {e}") |
| result["recommendation"] = {} |
|
|
| progress(1.0, desc="Complete") |
|
|
| rec_html = _build_rec_html(result) |
| intake_html = _build_intake_html(result) |
| lab_html, lab_df = _build_lab_html_and_df(result) |
| safety_html = _build_safety_html(result) |
|
|
| return rec_html, intake_html, lab_html, lab_df, safety_html, gr.update(visible=True) |
|
|
|
|
| |
|
|
| def switch_tool(tool): |
| tools = ["Empirical Advisor", "MIC Interpreter", "MIC Trend Analysis", "Drug Safety Check"] |
| return [gr.update(visible=(t == tool)) for t in tools] |
|
|
|
|
| def run_empirical(infection_type, pathogen, risk): |
| guidance = get_empirical_therapy_guidance(infection_type, list(risk or [])) |
| html = "" |
| for i, rec in enumerate(guidance.get("recommendations", [])[:3], 1): |
| score = rec.get("relevance_score", 0) |
| content = rec.get("content", "") |
| source = rec.get("source", "IDSA Guidelines 2024") |
| html += (f'<div class="badge-info"><strong>Excerpt {i}</strong>' |
| f' (relevance {score:.2f})<br>{content}<br><em>Source: {source}</em></div>') |
| if pathogen: |
| effective = get_most_effective_antibiotics(pathogen, min_susceptibility=70) |
| if effective: |
| items = "".join( |
| f"<li><strong>{ab.get('antibiotic')}</strong>" |
| f" — {ab.get('avg_susceptibility', 0):.1f}% susceptible</li>" |
| for ab in effective[:6] |
| ) |
| html += f"<br><strong>Resistance data — {pathogen}</strong><ul style='margin:4px 0 0 16px'>{items}</ul>" |
| else: |
| html += '<div class="badge-info">No resistance data available for this pathogen.</div>' |
| return html or '<div class="badge-info">No results found.</div>' |
|
|
|
|
| def run_mic_interpret(pathogen, antibiotic, mic): |
| if not pathogen or not antibiotic: |
| return '<div class="badge-info">Enter pathogen and antibiotic.</div>' |
| result = interpret_mic_value(pathogen, antibiotic, float(mic or 1.0)) |
| interp = result.get("interpretation", "UNKNOWN") |
| msg = result.get("message", "") |
| if interp == "SUSCEPTIBLE": |
| return f'<div class="badge-low"><strong>Susceptible (S)</strong> — {msg}</div>' |
| if interp == "RESISTANT": |
| return f'<div class="badge-high"><strong>Resistant (R)</strong> — {msg}</div>' |
| return f'<div class="badge-moderate"><strong>Intermediate (I)</strong> — {msg}</div>' |
|
|
|
|
| def update_mic_inputs(n): |
| return [gr.update(visible=(i < int(n))) for i in range(6)] |
|
|
|
|
| def run_mic_trend(n, m0, m1, m2, m3, m4, m5): |
| vals = [m0, m1, m2, m3, m4, m5][: int(n)] |
| mic_values = [{"date": f"T{i}", "mic_value": float(v or 1.0)} for i, v in enumerate(vals)] |
| result = calculate_mic_trend(mic_values) |
| risk = result.get("risk_level", "UNKNOWN") |
| alert = result.get("alert", "") |
| css = {"HIGH": "badge-high", "MODERATE": "badge-moderate"}.get(risk, "badge-low") |
| icon = {"HIGH": "🚨", "MODERATE": "⚠"}.get(risk, "✓") |
| base = result.get("baseline_mic", "—") |
| curr = result.get("current_mic", "—") |
| ratio = result.get("ratio", "—") |
| return f""" |
| <div class="{css}">{icon} <strong>{risk} RISK</strong> — {alert}</div> |
| <br> |
| <table><tr> |
| <td style='padding:8px 24px 8px 0'><strong>Baseline MIC</strong><br>{base} mg/L</td> |
| <td style='padding:8px 24px'><strong>Current MIC</strong><br>{curr} mg/L</td> |
| <td style='padding:8px 24px'><strong>Fold change</strong><br>{ratio}×</td> |
| </tr></table>""" |
|
|
|
|
| def run_drug_safety(ab, meds, allergies_txt): |
| if not ab: |
| return '<div class="badge-info">Enter an antibiotic to check.</div>' |
| med_list = [m.strip() for m in (meds or "").split("\n") if m.strip()] |
| allergy_list = [a.strip() for a in (allergies_txt or "").split("\n") if a.strip()] |
| result = screen_antibiotic_safety(ab, med_list, allergy_list) |
| if result.get("safe_to_use"): |
| html = '<div class="badge-low">✓ No critical safety concerns identified.</div>' |
| else: |
| html = '<div class="badge-high">⚠ Safety concerns identified — review required.</div>' |
| html += "".join( |
| f'<div class="badge-moderate" style="margin-top:8px">⚠ {a.get("message","")}</div>' |
| for a in result.get("alerts", []) |
| ) |
| return html |
|
|
|
|
| def run_guidelines_search(query, pathogen_filter): |
| if not query: |
| return '<div class="badge-info">Enter a search query.</div>' |
| filt = None if pathogen_filter == "All" else pathogen_filter |
| results = search_clinical_guidelines(query, pathogen_filter=filt, n_results=5) |
| if not results: |
| return ('<div class="badge-info">No results found. Try broader search terms or ' |
| 'check that the knowledge base has been initialised.</div>') |
| html = "" |
| for i, r in enumerate(results, 1): |
| score = r.get("relevance_score", 0) |
| content = r.get("content", "") |
| source = r.get("source", "") |
| src_str = f"<br><em>Source: {source}</em>" if source else "" |
| html += (f'<div class="badge-info"><strong>Result {i}</strong>' |
| f' · relevance {score:.2f}<br>{content}{src_str}</div>') |
| return html |
|
|
|
|
| |
|
|
| def _make_site_widget(field): |
| ftype = field["type"] |
| label = field["label"] |
| if ftype == "selectbox": |
| return gr.Dropdown(choices=field["options"], value=field["options"][0], label=label) |
| if ftype == "multiselect": |
| return gr.CheckboxGroup(choices=field["options"], label=label) |
| if ftype == "number_input": |
| return gr.Number( |
| value=field.get("default", 0), label=label, |
| minimum=field.get("min"), maximum=field.get("max"), |
| ) |
| if ftype == "checkbox": |
| return gr.Checkbox(value=field.get("default", False), label=label) |
| return gr.Textbox(label=label) |
|
|
|
|
| |
|
|
| _s = get_settings() |
| OVERVIEW_MODELS_MD = f""" |
| | Agent | Role | Model | |
| |---|---|---| |
| | 1, 2, 4 | Clinical reasoning | `{_s.medgemma_4b_model or "google/medgemma-4b-it"}` | |
| | 3 | Trend analysis | `{_s.medgemma_27b_model or "google/medgemma-27b-text-it"}` | |
| | 4 (safety) | Pharmacology check | `{_s.txgemma_9b_model or "google/txgemma-9b-predict"}` | |
| | — | Semantic retrieval | `{_s.embedding_model_name}` | |
| | — | Inference backend | HuggingFace Transformers · {_s.quantization} quant | |
| """ |
|
|
| |
|
|
| with gr.Blocks(theme=gr.themes.Soft(), css=CSS, title="AMR-Guard") as demo: |
| gr.HTML(BANNER_HTML) |
|
|
| with gr.Tabs(): |
|
|
| |
| with gr.Tab("Overview"): |
| gr.HTML(""" |
| <div class="section-title">System Overview</div> |
| <div class="stat-cards"> |
| <div class="stat-card"> |
| <div class="label">WHO AWaRe</div><div class="value">264</div><div class="sub">antibiotics classified</div> |
| </div> |
| <div class="stat-card"> |
| <div class="label">EUCAST</div><div class="value">v16.0</div><div class="sub">breakpoint tables</div> |
| </div> |
| <div class="stat-card"> |
| <div class="label">IDSA</div><div class="value">2024</div><div class="sub">treatment guidelines</div> |
| </div> |
| <div class="stat-card"> |
| <div class="label">DDInter</div><div class="value">191K+</div><div class="sub">drug interactions</div> |
| </div> |
| </div> |
| <div class="section-title">Agent Pipeline</div> |
| """) |
| with gr.Row(): |
| with gr.Column(): |
| gr.HTML(""" |
| <p><strong>Stage 1 — Empirical</strong> <em>(no lab results yet)</em></p> |
| <div class="agent-step"><div class="num">Agent 01</div><div class="name">Intake Historian</div> |
| <div class="desc">Parses patient data, calculates CrCl, identifies MDR risk factors</div></div> |
| <div class="agent-step"><div class="num">Agent 04</div><div class="name">Clinical Pharmacologist</div> |
| <div class="desc">Empirical antibiotic selection · WHO AWaRe · safety screening</div></div> |
| """) |
| with gr.Column(): |
| gr.HTML(""" |
| <p><strong>Stage 2 — Targeted</strong> <em>(culture / sensitivity available)</em></p> |
| <div class="agent-step"><div class="num">Agent 01</div><div class="name">Intake Historian</div> |
| <div class="desc">Same as Stage 1</div></div> |
| <div class="agent-step"><div class="num">Agent 02</div><div class="name">Vision Specialist</div> |
| <div class="desc">Extracts structured data from lab reports (any language / format)</div></div> |
| <div class="agent-step"><div class="num">Agent 03</div><div class="name">Trend Analyst</div> |
| <div class="desc">Detects MIC creep · calculates resistance velocity</div></div> |
| <div class="agent-step"><div class="num">Agent 04</div><div class="name">Clinical Pharmacologist</div> |
| <div class="desc">Targeted recommendation informed by susceptibility data</div></div> |
| """) |
| gr.HTML('<div class="section-title">AI Models (Local)</div>') |
| gr.Markdown(OVERVIEW_MODELS_MD) |
| gr.HTML( |
| '<div class="disclaimer">⚠ <strong>Research demo only.</strong> ' |
| "Not validated for clinical use. All recommendations must be reviewed " |
| "by a licensed clinician before any patient-care decision.</div>" |
| ) |
|
|
| |
| with gr.Tab("Patient Analysis"): |
| gr.HTML('<div class="section-title">Patient Analysis Pipeline</div>') |
|
|
| |
| with gr.Row(): |
| with gr.Column(scale=1): |
| age = gr.Number(value=65, label="Age (years)", minimum=0, maximum=120, precision=0) |
| weight = gr.Number(value=70.0, label="Weight (kg)", minimum=1.0, maximum=300.0) |
| height = gr.Number(value=170.0,label="Height (cm)", minimum=50.0,maximum=250.0) |
| with gr.Column(scale=1): |
| sex = gr.Dropdown(choices=["male", "female"], value="male", label="Biological sex") |
| creatinine_main = gr.Number(value=1.2, label="Serum Creatinine (mg/dL)", |
| minimum=0.1, maximum=20.0, visible=True) |
| renal_flag = gr.Checkbox(label="Known renal impairment / CKD?", visible=False) |
| creatinine_optional = gr.Number(value=1.2, label="Serum Creatinine (mg/dL)", |
| minimum=0.1, maximum=20.0, visible=False) |
| with gr.Column(scale=1): |
| infection_site = gr.Dropdown(choices=INFECTION_SITES, value="urinary", |
| label="Primary infection site") |
| _init_src = SUSPECTED_SOURCE_OPTIONS.get("urinary", []) |
| suspected_source = gr.Dropdown(choices=_init_src, |
| value=_init_src[0] if _init_src else None, |
| label="Suspected source") |
|
|
| |
| site_groups: dict = {} |
| |
| u_comps: list = [] |
| r_comps: list = [] |
| b_comps: list = [] |
| sk_comps: list = [] |
| ia_comps: list = [] |
| cn_comps: list = [] |
|
|
| for site in INFECTION_SITES: |
| fields = SITE_SPECIFIC_FIELDS.get(site, []) |
| with gr.Group(visible=(site == "urinary")) as grp: |
| if fields: |
| gr.HTML(f'<div class="section-title">{site.title()} — Assessment</div>') |
| with gr.Row(): |
| for field in fields: |
| comp = _make_site_widget(field) |
| if site == "urinary": |
| u_comps.append(comp) |
| elif site == "respiratory": |
| r_comps.append(comp) |
| elif site == "bloodstream": |
| b_comps.append(comp) |
| elif site == "skin": |
| sk_comps.append(comp) |
| elif site == "intra-abdominal": |
| ia_comps.append(comp) |
| elif site == "CNS": |
| cn_comps.append(comp) |
| site_groups[site] = grp |
|
|
| |
| all_site_inputs = u_comps + r_comps + b_comps + sk_comps + ia_comps + cn_comps |
|
|
| |
| gr.HTML('<div class="section-title">Medical History</div>') |
| with gr.Row(): |
| with gr.Column(): |
| medications = gr.Textbox( |
| label="Current medications (one per line)", |
| placeholder="Metformin\nLisinopril", lines=4, |
| ) |
| allergies = gr.Textbox( |
| label="Drug allergies (one per line)", |
| placeholder="Penicillin\nSulfa", lines=3, |
| ) |
| with gr.Column(): |
| comorbidities = gr.CheckboxGroup( |
| choices=["Diabetes", "CKD", "Heart Failure", "COPD", |
| "Immunocompromised", "Recent Surgery", "Malignancy", "Liver Disease"], |
| label="Comorbidities", |
| ) |
| risk_factors = gr.CheckboxGroup( |
| choices=["Prior MRSA", "Recent antibiotics (<90 d)", "Healthcare-associated", |
| "Recent hospitalisation", "Nursing home", "Prior MDR infection"], |
| label="MDR risk factors", |
| ) |
|
|
| |
| gr.HTML('<div class="section-title">Lab / Culture Results ' |
| '<small>(optional — triggers targeted pathway)</small></div>') |
| lab_method = gr.Radio( |
| choices=["None — empirical pathway only", "Upload file (PDF / image)", "Paste lab text"], |
| value="None — empirical pathway only", |
| label="Input method", |
| ) |
| lab_file = gr.File( |
| label="Lab report", |
| file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"], |
| visible=False, |
| ) |
| lab_paste = gr.Textbox( |
| label="Lab report text", |
| placeholder=( |
| "Organism: Escherichia coli\n" |
| "Ciprofloxacin: S MIC 0.25\n" |
| "Nitrofurantoin: S MIC 16\n" |
| "Ampicillin: R MIC >32" |
| ), |
| lines=5, visible=False, |
| ) |
|
|
| run_btn = gr.Button("Run Agent Pipeline", variant="primary") |
|
|
| |
| with gr.Group(visible=False) as results_group: |
| gr.HTML('<div class="section-title">Results</div>') |
| with gr.Tabs(): |
| with gr.Tab("Recommendation"): |
| rec_out = gr.HTML() |
| with gr.Tab("Patient Summary"): |
| intake_out = gr.HTML() |
| with gr.Tab("Lab Analysis"): |
| lab_html_out = gr.HTML() |
| lab_df_out = gr.DataFrame(label="Susceptibility Table", wrap=True) |
| with gr.Tab("Safety"): |
| safety_out = gr.HTML() |
|
|
| |
| infection_site.change( |
| fn=update_site_ui, |
| inputs=[infection_site], |
| outputs=[ |
| *[site_groups[s] for s in INFECTION_SITES], |
| suspected_source, |
| creatinine_main, |
| renal_flag, |
| creatinine_optional, |
| ], |
| ) |
| renal_flag.change( |
| fn=toggle_optional_creatinine, |
| inputs=[renal_flag], |
| outputs=[creatinine_optional], |
| ) |
| lab_method.change( |
| fn=toggle_lab_inputs, |
| inputs=[lab_method], |
| outputs=[lab_file, lab_paste], |
| ) |
| _loading_html = '<div class="badge-info" style="padding:16px;text-align:center;">⏳ Pipeline running — please wait…</div>' |
| run_btn.click( |
| fn=lambda: ( |
| _loading_html, _loading_html, _loading_html, |
| pd.DataFrame(), _loading_html, |
| gr.update(visible=True), |
| ), |
| inputs=[], |
| outputs=[rec_out, intake_out, lab_html_out, lab_df_out, safety_out, results_group], |
| queue=False, |
| ).then( |
| fn=run_pipeline_ui, |
| inputs=[ |
| age, weight, height, sex, |
| creatinine_main, renal_flag, creatinine_optional, |
| infection_site, suspected_source, |
| *all_site_inputs, |
| medications, allergies, comorbidities, risk_factors, |
| lab_method, lab_file, lab_paste, |
| ], |
| outputs=[rec_out, intake_out, lab_html_out, lab_df_out, safety_out, results_group], |
| ) |
|
|
| |
| with gr.Tab("Clinical Tools"): |
| gr.HTML('<div class="section-title">Clinical Tools</div>') |
| tool_sel = gr.Dropdown( |
| choices=["Empirical Advisor", "MIC Interpreter", "MIC Trend Analysis", "Drug Safety Check"], |
| value="Empirical Advisor", |
| label="Select tool", |
| ) |
|
|
| |
| with gr.Group(visible=True) as grp_ea: |
| with gr.Row(): |
| with gr.Column(scale=3): |
| ea_infection = gr.Dropdown( |
| choices=["Urinary Tract Infection", "Pneumonia", "Sepsis", |
| "Skin / Soft Tissue", "Intra-abdominal", "Meningitis"], |
| value="Urinary Tract Infection", label="Infection type", |
| ) |
| ea_pathogen = gr.Textbox( |
| label="Suspected pathogen (optional)", |
| placeholder="e.g., Klebsiella pneumoniae", |
| ) |
| ea_risk = gr.CheckboxGroup( |
| choices=["Prior MRSA", "Recent antibiotics (<90 d)", "Healthcare-associated", |
| "Immunocompromised", "Renal impairment", "Prior MDR"], |
| label="Risk factors", |
| ) |
| with gr.Column(scale=1): |
| gr.HTML(""" |
| <div class="badge-info"><strong style="color:#dceeff">WHO AWaRe</strong><br> |
| <span style="color:#86efac">●</span> Access — first-line<br> |
| <span style="color:#fcd34d">●</span> Watch — second-line<br> |
| <span style="color:#fca5a5">●</span> Reserve — last resort</div>""") |
| ea_btn = gr.Button("Get recommendation", variant="primary") |
| ea_out = gr.HTML() |
|
|
| |
| with gr.Group(visible=False) as grp_mi: |
| with gr.Row(): |
| with gr.Column(): |
| mi_pathogen = gr.Textbox(label="Pathogen", placeholder="e.g., Escherichia coli") |
| mi_antibiotic= gr.Textbox(label="Antibiotic", placeholder="e.g., Ciprofloxacin") |
| mi_mic = gr.Number(value=1.0, label="MIC value (mg/L)", minimum=0.001, maximum=1024.0) |
| with gr.Column(): |
| gr.HTML(""" |
| <div class="badge-info" style="margin-top:28px"><strong>Interpretation guide</strong><br><br> |
| <strong>S</strong> Susceptible — antibiotic is effective<br> |
| <strong>I</strong> Intermediate — effective at higher doses<br> |
| <strong>R</strong> Resistant — do not use</div>""") |
| mi_btn = gr.Button("Interpret", variant="primary") |
| mi_out = gr.HTML() |
|
|
| |
| with gr.Group(visible=False) as grp_mt: |
| mt_n = gr.Slider(minimum=2, maximum=6, value=3, step=1, |
| label="Number of historical readings") |
| with gr.Row(): |
| mt_m = [ |
| gr.Number(value=float(2 ** i), label=f"MIC {i+1} (mg/L)", |
| minimum=0.001, maximum=256.0, visible=(i < 3)) |
| for i in range(6) |
| ] |
| mt_btn = gr.Button("Analyse trend", variant="primary") |
| mt_out = gr.HTML() |
| mt_n.change(fn=update_mic_inputs, inputs=[mt_n], outputs=mt_m) |
|
|
| |
| with gr.Group(visible=False) as grp_ds: |
| with gr.Row(): |
| with gr.Column(): |
| ds_ab = gr.Textbox(label="Antibiotic to check", |
| placeholder="e.g., Ciprofloxacin") |
| ds_meds = gr.Textbox(label="Concurrent medications", |
| placeholder="Warfarin\nMetformin\nAmlodipine", lines=4) |
| with gr.Column(): |
| ds_allergies = gr.Textbox(label="Known allergies", |
| placeholder="Penicillin\nSulfa", lines=3) |
| ds_btn = gr.Button("Check safety", variant="primary") |
| ds_out = gr.HTML() |
|
|
| tool_sel.change( |
| fn=switch_tool, inputs=[tool_sel], |
| outputs=[grp_ea, grp_mi, grp_mt, grp_ds], |
| ) |
| ea_btn.click(fn=run_empirical, inputs=[ea_infection, ea_pathogen, ea_risk], outputs=[ea_out]) |
| mi_btn.click(fn=run_mic_interpret, inputs=[mi_pathogen, mi_antibiotic, mi_mic], outputs=[mi_out]) |
| mt_btn.click(fn=run_mic_trend, inputs=[mt_n, *mt_m], outputs=[mt_out]) |
| ds_btn.click(fn=run_drug_safety, inputs=[ds_ab, ds_meds, ds_allergies], outputs=[ds_out]) |
|
|
| |
| with gr.Tab("Guidelines"): |
| gr.HTML('<div class="section-title">Clinical Guidelines Search</div>') |
| with gr.Row(): |
| gl_query = gr.Textbox( |
| label="Search query", |
| placeholder="e.g., ESBL E. coli UTI treatment carbapenems", |
| scale=3, |
| ) |
| gl_filter = gr.Dropdown( |
| choices=["All", "ESBL-E", "CRE", "CRAB", "DTR-PA"], |
| value="All", label="Filter by pathogen", scale=1, |
| ) |
| gl_btn = gr.Button("Search", variant="primary") |
| gl_out = gr.HTML() |
| gr.HTML( |
| '<div class="disclaimer">Sources: IDSA Treatment Guidelines 2024 · ' |
| "EUCAST Breakpoint Tables v16.0 · WHO EML · DDInter drug interaction database.</div>" |
| ) |
| gl_btn.click(fn=run_guidelines_search, inputs=[gl_query, gl_filter], outputs=[gl_out]) |
|
|
|
|
| if __name__ == "__main__": |
| demo.launch() |
|
|