mgbam commited on
Commit
32f6443
·
verified ·
1 Parent(s): 5ddd206

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +424 -600
app.py CHANGED
@@ -1,128 +1,76 @@
1
  """
2
- 🫁 AST Chest X-Ray Lab
3
- Multi-Class Chest X-Ray Detection (Normal · TB · Pneumonia · COVID-19)
4
- with Adaptive Sparse Training & Explainable AI (Grad-CAM)
 
 
 
 
 
 
5
  """
6
 
7
- import io
8
- from pathlib import Path
9
-
10
- import cv2
11
  import gradio as gr
12
- import matplotlib
13
- matplotlib.use("Agg") # safe backend for servers
14
- import matplotlib.pyplot as plt
15
- import numpy as np
16
  import torch
17
  import torch.nn as nn
18
- from PIL import Image
19
  from torchvision import models, transforms
 
 
 
 
 
 
20
 
21
  # ============================================================================
22
  # Model Setup
23
  # ============================================================================
24
 
25
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
-
27
- # EfficientNet backbone we want 4 output classes
28
- NUM_CLASSES = 4
29
- model = models.efficientnet_b0(weights=None)
30
- model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)
31
-
32
- # We expect a 4-class EfficientNet checkpoint here
33
- checkpoint_candidates = [
34
- "checkpoints/best.pt", # main location from your screenshot
35
- "best.pt", # optional fallback in repo root
36
- ]
37
-
38
- MODEL_LOAD_INFO = ""
39
- loaded = False
40
-
41
-
42
- def extract_state_dict(ckpt):
43
- """
44
- Handle both:
45
- - plain state_dict (just parameter tensors)
46
- - training checkpoints with keys like 'model_state_dict', 'state_dict', etc.
47
- """
48
- if isinstance(ckpt, dict):
49
- for key in ["model_state_dict", "state_dict", "model"]:
50
- if key in ckpt and isinstance(ckpt[key], dict):
51
- return ckpt[key]
52
- return ckpt # already a raw state dict
53
-
54
-
55
- for ckpt_path in checkpoint_candidates:
56
- if Path(ckpt_path).is_file():
57
- try:
58
- print(f"🔍 Trying to load weights from: {ckpt_path}")
59
- raw_ckpt = torch.load(ckpt_path, map_location=device)
60
- state_dict = extract_state_dict(raw_ckpt)
61
-
62
- # Check classifier size to ensure it's truly 4-class
63
- if "classifier.1.weight" in state_dict:
64
- out_features = state_dict["classifier.1.weight"].shape[0]
65
- if out_features != NUM_CLASSES:
66
- raise ValueError(
67
- f"Checkpoint at {ckpt_path} has {out_features} output "
68
- f"classes, but this app expects {NUM_CLASSES}."
69
- )
70
-
71
- # Load strict – we want the full EfficientNet weights
72
- model.load_state_dict(state_dict, strict=True)
73
-
74
- MODEL_LOAD_INFO = (
75
- f"✅ Model loaded from **{ckpt_path}** on **{device.type.upper()}**."
76
- )
77
- loaded = True
78
- break
79
- except Exception as e:
80
- print(f"⚠️ Found {ckpt_path} but failed to load model_state_dict: {e}")
81
-
82
- if not loaded:
83
- raise RuntimeError(
84
- "Model file not found or could not be loaded.\n"
85
- "Expected a 4-class EfficientNet checkpoint at 'checkpoints/best.pt' "
86
- "or 'best.pt' that was saved with model.state_dict().\n"
87
- "If you saved a training checkpoint, make sure it has a "
88
- "'model_state_dict' key with the 4-class EfficientNet weights."
89
- )
90
 
91
  model = model.to(device)
92
  model.eval()
93
 
94
- TOTAL_PARAMS = sum(p.numel() for p in model.parameters())
95
- TOTAL_PARAMS_M = TOTAL_PARAMS / 1e6
96
-
97
- # ============================================================================
98
- # Classes & Preprocessing
99
- # ============================================================================
100
-
101
- CLASSES = ["Normal", "Tuberculosis", "Pneumonia", "COVID-19"]
102
  CLASS_COLORS = {
103
- "Normal": "#22c55e", # Green
104
- "Tuberculosis": "#ef4444", # Red
105
- "Pneumonia": "#f97316", # Orange
106
- "COVID-19": "#a855f7", # Purple
107
  }
108
 
109
- transform = transforms.Compose(
110
- [
111
- transforms.Resize(256),
112
- transforms.CenterCrop(224),
113
- transforms.ToTensor(),
114
- transforms.Normalize(
115
- [0.485, 0.456, 0.406],
116
- [0.229, 0.224, 0.225],
117
- ),
118
- ]
119
- )
120
 
121
  # ============================================================================
122
  # Grad-CAM Implementation
123
  # ============================================================================
124
 
125
-
126
  class GradCAM:
127
  def __init__(self, model, target_layer):
128
  self.model = model
@@ -147,7 +95,7 @@ class GradCAM:
147
 
148
  self.model.zero_grad()
149
  one_hot = torch.zeros_like(output)
150
- one_hot[0, target_class] = 1
151
  output.backward(gradient=one_hot, retain_graph=True)
152
 
153
  if self.gradients is None:
@@ -161,623 +109,471 @@ class GradCAM:
161
 
162
  return cam, output
163
 
164
-
165
  target_layer = model.features[-1]
166
  grad_cam = GradCAM(model, target_layer)
167
 
168
  # ============================================================================
169
- # Visualization Helpers
170
  # ============================================================================
171
 
 
 
 
 
 
 
172
 
173
- def _figure_to_pil():
174
- buf = io.BytesIO()
175
- plt.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white")
176
- plt.close()
177
- buf.seek(0)
178
- return Image.open(buf)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  def create_original_display(image, pred_label, confidence):
182
- fig, ax = plt.subplots(figsize=(7, 7))
 
183
  ax.imshow(image)
184
- ax.axis("off")
185
 
 
186
  color = CLASS_COLORS[pred_label]
187
- title = f"Prediction: {pred_label} • Confidence: {confidence:.1f}%"
188
- ax.set_title(
189
- title,
190
- fontsize=16,
191
- fontweight="bold",
192
- color=color,
193
- pad=20,
194
- )
195
  plt.tight_layout()
196
- return _figure_to_pil()
197
 
 
 
 
 
 
198
 
199
- def create_gradcam_visualization(image, cam):
 
 
 
 
200
  img_array = np.array(image.resize((224, 224)))
201
  cam_resized = cv2.resize(cam, (224, 224))
202
 
 
203
  heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
204
  heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
205
 
206
- fig, ax = plt.subplots(figsize=(7, 7))
207
  ax.imshow(heatmap)
208
- ax.axis("off")
209
- ax.set_title(
210
- "Attention Heatmap\n(Where the model is looking)",
211
- fontsize=14,
212
- fontweight="bold",
213
- pad=20,
214
- )
215
  plt.tight_layout()
216
- return _figure_to_pil()
217
 
 
 
 
 
 
 
218
 
219
  def create_overlay_visualization(image, cam):
 
220
  img_array = np.array(image.resize((224, 224))) / 255.0
221
  cam_resized = cv2.resize(cam, (224, 224))
222
 
 
223
  heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
224
  heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0
225
 
 
226
  overlay = img_array * 0.5 + heatmap * 0.5
227
  overlay = np.clip(overlay, 0, 1)
228
 
229
- fig, ax = plt.subplots(figsize=(7, 7))
230
  ax.imshow(overlay)
231
- ax.axis("off")
232
- ax.set_title(
233
- "Explainable AI Overlay\n(Anatomy + Attention)",
234
- fontsize=14,
235
- fontweight="bold",
236
- pad=20,
237
- )
238
  plt.tight_layout()
239
- return _figure_to_pil()
240
 
241
- # ============================================================================
242
- # Interpretation
243
- # ============================================================================
 
244
 
 
245
 
246
- def create_interpretation(pred_label, confidence, results, audience="Clinician"):
247
- header_note = {
248
- "Clinician": "This view is tuned for **clinical decision support** (not a replacement for your judgement).",
249
- "Researcher": "This view is tuned for **model behavior understanding** and experimental workflows.",
250
- "Patient / Public": "This view is tuned for **patient-friendly language**. Always discuss results with a doctor.",
251
- }.get(audience, "Use this output as a **screening aid**, not a final diagnosis.")
252
 
253
  interpretation = f"""
254
- ## 🔬 Analysis Results ({audience} View)
255
-
256
- > {header_note}
257
-
258
- ### Primary Prediction: **{pred_label}**
259
  - Confidence: **{confidence:.1f}%**
260
-
261
- ### Probability Breakdown
262
  - 🟢 Normal: **{results['Normal']:.1f}%**
263
  - 🔴 Tuberculosis: **{results['Tuberculosis']:.1f}%**
264
  - 🟠 Pneumonia: **{results['Pneumonia']:.1f}%**
265
  - 🟣 COVID-19: **{results['COVID-19']:.1f}%**
266
-
267
  ---
268
  """
269
 
270
- if pred_label == "Tuberculosis":
 
271
  if confidence >= 85:
272
  interpretation += """
273
- ### 🧫 TB Pattern – High Confidence
274
-
275
- The model has detected features strongly suggestive of **pulmonary tuberculosis**.
276
-
277
- **Suggested Clinical Pathway**
278
- 1. Immediate review by a clinician or chest physician
279
- 2. Sputum testing (AFB smear, GeneXpert MTB/RIF, or TB-PCR)
280
- 3. Correlate with symptoms: chronic cough, weight loss, night sweats, fever, hemoptysis
281
- 4. Consider CT or further imaging if uncertainty remains
282
- 5. Infection control and contact tracing if TB is confirmed
 
 
 
 
283
  """
284
  else:
285
  interpretation += """
286
- ### 🧫 TB Pattern – Possible
287
-
288
- The scan shows features that **could** be compatible with tuberculosis, but confidence is moderate.
289
-
290
- - Correlate with symptoms and risk factors
291
- - Consider sputum testing where suspicion is non-trivial
292
- - Follow-up imaging as clinically indicated
 
293
  """
294
 
295
- elif pred_label == "Pneumonia":
296
  if confidence >= 85:
297
  interpretation += """
298
- ### 🌫 Pneumonia Pattern – High Confidence
299
-
300
- The model has detected an opacity pattern consistent with **pneumonia**.
301
-
302
- Typical clinical correlates:
303
-
304
- - Fever, productive cough
305
- - Shortness of breath
306
- - Pleuritic chest pain
307
-
308
- Use alongside clinical exam, labs, and local treatment guidelines.
 
 
 
 
 
309
  """
310
  else:
311
  interpretation += """
312
- ### 🌫 Pneumonia Pattern – Possible
313
-
314
- Findings may be compatible with pneumonia, but alternative explanations exist.
315
-
316
- - Check vital signs and auscultation
317
- - Labs (WBC, CRP, cultures) may be useful
318
- - Consider watchful follow-up or repeat imaging
319
  """
320
 
321
- elif pred_label == "COVID-19":
322
  if confidence >= 85:
323
  interpretation += """
324
- ### 🦠 COVID-19 Pattern – High Confidence
325
-
326
- Distribution and appearance of opacities are compatible with **COVID-19 pneumonia**.
327
-
328
- ⚠️ Imaging alone is **not diagnostic**.
329
-
330
- - Confirm with RT-PCR or validated antigen testing
331
- - Follow local isolation and infection-control policies
332
- - Monitor SpO₂ and respiratory status; escalate care if deterioration occurs
 
 
 
 
 
 
 
 
 
 
333
  """
334
  else:
335
  interpretation += """
336
- ### 🦠 COVID-19 Pattern – Possible
337
-
338
- Some features may overlap with COVID-19, but there is substantial uncertainty.
339
-
340
- - Testing (RT-PCR / antigen) is essential
341
- - Integrate exposure history and symptoms
 
 
342
  """
343
 
344
  else: # Normal
345
  if confidence >= 85:
346
  interpretation += """
347
- ### No Major Abnormality Detected
348
-
349
- The model did **not** detect strong features of TB, pneumonia, or COVID-19.
350
-
351
- Important caveats:
352
-
353
- - Early disease or small lesions may be missed
354
- - Non-infective conditions (e.g., cancer, ILD) are **not** specifically evaluated
355
- - Persistent symptoms still warrant clinical review
 
 
 
 
 
 
 
356
  """
357
  else:
358
  interpretation += """
359
- ### ℹ️ Likely Normal, But Low Confidence
360
-
361
- The scan leans towards **normal**, but the model is not highly confident.
362
-
363
- - Consider follow-up or additional tests if symptoms persist
 
 
364
  """
365
 
 
366
  interpretation += """
367
  ---
368
  ## ⚠️ CRITICAL MEDICAL DISCLAIMER
369
-
370
- - This AI model is a **screening / decision-support tool only**
371
- - It is **not FDA-approved** and must **not** be used as a stand-alone diagnostic device
372
- - Always integrate:
373
- - Clinical history and examination
374
- - Laboratory tests (sputum, PCR, cultures, etc.)
375
- - Expert radiologist review
376
-
377
- **Gold Standards**
378
-
379
- - TB: Sputum AFB / culture, GeneXpert MTB/RIF, TB-PCR
380
- - Pneumonia: Clinical diagnosis + labs / microbiology
381
- - COVID-19: RT-PCR or validated antigen tests
382
-
383
- When in doubt, consult a qualified healthcare professional.
 
 
 
 
 
 
 
384
  ---
385
- 🫁 **Powered by Adaptive Sparse Training (AST)**
386
- Energy-efficient deep learning designed to make chest X-ray screening more accessible.
387
-
388
- **Links**
389
-
390
- - GitHub: https://github.com/oluwafemidiakhoa/Tuberculosis
391
- - Hugging Face Space: https://huggingface.co/spaces/mgbam/Tuberculosis
392
  """
393
- return interpretation
394
-
395
- # ============================================================================
396
- # Prediction Pipeline
397
- # ============================================================================
398
-
399
-
400
- def predict_chest_xray(image, show_gradcam=True, audience="Clinician"):
401
- """
402
- Main inference function used by Gradio.
403
- Returns:
404
- - dict of class probabilities
405
- - annotated original
406
- - grad-cam heatmap
407
- - overlay
408
- - full markdown report
409
- - short textual snapshot
410
- """
411
- if image is None:
412
- msg = "👋 Upload a chest X-ray (PNG/JPG) and click **Analyze** to generate a full AI report."
413
- return {}, None, None, None, msg, "Awaiting image upload…"
414
-
415
- if isinstance(image, np.ndarray):
416
- image = Image.fromarray(image).convert("RGB")
417
- else:
418
- image = image.convert("RGB")
419
-
420
- original_img = image.copy()
421
- input_tensor = transform(image).unsqueeze(0).to(device)
422
 
423
- with torch.set_grad_enabled(show_gradcam):
424
- if show_gradcam:
425
- cam, output = grad_cam.generate(input_tensor)
426
- else:
427
- output = model(input_tensor)
428
- cam = None
429
-
430
- probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
431
- prob_sum = float(np.sum(probs))
432
-
433
- if not (0.99 <= prob_sum <= 1.01):
434
- print(f"⚠️ WARNING: Probability sum is {prob_sum}, not ≈1.0 – check model weights.")
435
-
436
- pred_class = int(output.argmax(dim=1).item())
437
- pred_label = CLASSES[pred_class]
438
- confidence = float(probs[pred_class]) * 100.0
439
-
440
- results = {
441
- CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100.0)))
442
- for i in range(len(CLASSES))
443
- }
444
-
445
- original_pil = create_original_display(original_img, pred_label, confidence)
446
- gradcam_viz = create_gradcam_visualization(original_img, cam) if cam is not None else None
447
- overlay_viz = create_overlay_visualization(original_img, cam) if cam is not None else None
448
-
449
- interpretation = create_interpretation(pred_label, confidence, results, audience=audience)
450
- snapshot = f"**{pred_label}** · {confidence:.1f}% confidence • Prob. sum: {prob_sum:.3f}"
451
-
452
- return results, original_pil, gradcam_viz, overlay_viz, interpretation, snapshot
453
 
454
  # ============================================================================
455
- # WOW UI / UX – Gradio App
456
  # ============================================================================
457
 
 
458
  custom_css = """
459
- :root {
460
- --primary: #6366f1;
461
- --primary-soft: rgba(99, 102, 241, 0.12);
462
- --accent: #ec4899;
463
- }
464
-
465
- .gradio-container {
466
- font-family: system-ui, -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
467
- background: radial-gradient(circle at top left, #111827 0, #020617 50%, #020617 100%);
468
- color: #e5e7eb;
469
- }
470
-
471
- #hero {
472
- padding: 24px 24px 8px 24px;
473
- border-radius: 24px;
474
- background: linear-gradient(120deg, rgba(99,102,241,0.18), rgba(236,72,153,0.14));
475
- border: 1px solid rgba(148, 163, 184, 0.4);
476
- box-shadow: 0 24px 60px rgba(15,23,42,0.85);
477
- backdrop-filter: blur(18px);
478
- }
479
-
480
- .hero-title {
481
- font-size: 2.4rem;
482
- font-weight: 800;
483
- letter-spacing: 0.04em;
484
- color: #f9fafb;
485
- margin-bottom: 6px;
486
- }
487
-
488
- .hero-subtitle {
489
- font-size: 0.98rem;
490
- color: #e5e7eb;
491
  }
492
-
493
- .hero-chip-row {
494
- display: flex;
495
- flex-wrap: wrap;
496
- gap: 8px;
497
- margin-top: 14px;
498
- }
499
-
500
- .hero-chip {
501
- padding: 4px 10px;
502
- border-radius: 999px;
503
- font-size: 0.78rem;
504
- background: rgba(15,23,42,0.8);
505
- border: 1px solid rgba(148,163,184,0.5);
506
- display: inline-flex;
507
- align-items: center;
508
- gap: 6px;
509
- color: #e5e7eb;
510
- }
511
-
512
- .pulse-dot {
513
- width: 8px;
514
- height: 8px;
515
- border-radius: 999px;
516
- background: #22c55e;
517
- box-shadow: 0 0 0 0 rgba(34,197,94,0.7);
518
- animation: pulse 1.4s infinite;
519
- }
520
-
521
- @keyframes pulse {
522
- 0% { box-shadow: 0 0 0 0 rgba(34,197,94,0.7); }
523
- 70% { box-shadow: 0 0 0 10px rgba(34,197,94,0); }
524
- 100% { box-shadow: 0 0 0 0 rgba(34,197,94,0); }
525
- }
526
-
527
- .glass-card {
528
- background: rgba(15,23,42,0.82);
529
- border-radius: 18px;
530
- border: 1px solid rgba(148,163,184,0.4);
531
- box-shadow: 0 18px 40px rgba(15,23,42,0.85);
532
- padding: 18px;
533
- backdrop-filter: blur(16px);
534
  }
535
-
536
- .glass-card-light {
537
- background: rgba(15,23,42,0.65);
538
- border-radius: 18px;
539
- border: 1px solid rgba(148,163,184,0.3);
540
- box-shadow: 0 12px 24px rgba(15,23,42,0.85);
541
- padding: 16px;
542
- backdrop-filter: blur(12px);
543
  }
544
-
545
- .stat-pill {
546
- padding: 10px 12px;
547
- border-radius: 14px;
548
- background: rgba(15,23,42,0.9);
549
- border: 1px solid rgba(148,163,184,0.5);
550
- font-size: 0.78rem;
551
- display: flex;
552
- flex-direction: column;
553
- gap: 2px;
554
  }
555
-
556
- .stat-pill-label {
557
- color: #9ca3af;
558
- text-transform: uppercase;
559
- font-size: 0.68rem;
560
  }
561
-
562
- .stat-pill-value {
563
- color: #e5e7eb;
564
- font-weight: 600;
 
565
  }
566
-
567
- .dropzone-image img {
568
- border-radius: 16px !important;
 
 
569
  }
570
-
571
- .output-image img {
572
- border-radius: 16px !important;
573
  }
574
-
575
  footer {
576
  text-align: center;
577
- margin-top: 24px;
578
- color: #9ca3af;
579
- font-size: 0.78rem;
580
  }
581
  """
582
 
583
- theme = gr.themes.Soft(
584
- primary_hue="indigo",
585
- secondary_hue="pink",
586
- neutral_hue="slate",
587
- ).set(
588
- button_primary_background_fill="linear-gradient(135deg,#4f46e5,#ec4899)",
589
- button_primary_background_fill_hover="linear-gradient(135deg,#6366f1,#f97316)",
590
- )
591
-
592
- with gr.Blocks(css=custom_css, theme=theme) as demo:
593
- # HERO
594
- gr.HTML(
595
- f"""
596
- <div id="hero">
597
- <div style="display:flex;justify-content:space-between;gap:16px;align-items:flex-start;">
598
- <div>
599
- <div class="hero-title">🫁 AST Chest X-Ray Lab</div>
600
- <div class="hero-subtitle">
601
- Multi-class chest X-ray analysis with <b>Explainable AI</b> and
602
- <b>Adaptive Sparse Training</b> – Normal · Tuberculosis · Pneumonia · COVID-19.
603
- </div>
604
- <div class="hero-chip-row">
605
- <div class="hero-chip">
606
- <span class="pulse-dot"></span>
607
- Live Inference
608
- </div>
609
- <div class="hero-chip">
610
- EfficientNet-B0 · ~{TOTAL_PARAMS_M:.1f}M params
611
- </div>
612
- <div class="hero-chip">
613
- 95–97% validation accuracy · ~89% energy savings
614
- </div>
615
- <div class="hero-chip">
616
- {MODEL_LOAD_INFO}
617
- </div>
618
- </div>
619
- </div>
620
- <div style="min-width:210px;display:flex;flex-direction:column;gap:8px;">
621
- <div class="stat-pill">
622
- <div class="stat-pill-label">Device</div>
623
- <div class="stat-pill-value">{device.type.upper()}</div>
624
- </div>
625
- <div class="stat-pill">
626
- <div class="stat-pill-label">Task</div>
627
- <div class="stat-pill-value">Normal · TB · Pneumonia · COVID-19</div>
628
- </div>
629
- </div>
630
  </div>
631
  </div>
632
- """
633
- )
634
-
635
- gr.Markdown(" ")
636
-
637
- with gr.Row(equal_height=True):
638
- # LEFT: INPUT PANEL
639
- with gr.Column(scale=1, elem_classes="glass-card"):
640
- gr.Markdown("### 1️⃣ Upload & Configure")
641
 
 
 
 
642
  image_input = gr.Image(
643
  type="pil",
644
- label="Drop a chest X-ray here",
645
- elem_classes=["dropzone-image"],
646
  )
647
 
648
- with gr.Row():
649
- show_gradcam = gr.Checkbox(
650
- value=True,
651
- label="Explainable AI (Grad-CAM)",
652
- info="Highlight regions that drive the prediction",
653
- )
654
- audience_select = gr.Radio(
655
- ["Clinician", "Researcher", "Patient / Public"],
656
- value="Clinician",
657
- label="Report Style",
658
- )
659
-
660
- with gr.Row():
661
- analyze_btn = gr.Button("🔬 Analyze X-Ray", variant="primary", scale=3)
662
- clear_btn = gr.Button("🧹 Reset", variant="secondary")
663
-
664
- gr.Markdown(
665
- """
666
- **Tips**
667
 
668
- - Use frontal (PA/AP) chest X-rays in PNG / JPG format
669
- - This tool is best used as a **triage / screening assistant**
670
- - For noisy or rotated images, consider preprocessing before upload
671
- """
672
  )
673
 
674
- # RIGHT: RESULTS PANEL
675
- with gr.Column(scale=2, elem_classes="glass-card-light"):
676
- gr.Markdown("### 2️⃣ AI Dashboard")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  with gr.Tabs():
679
- with gr.Tab("Snapshot"):
680
- snapshot_output = gr.Markdown(
681
- "No scan analyzed yet. Upload an X-ray to get started."
682
- )
683
- prob_output = gr.Label(
684
- label="Prediction Confidence (All Classes)",
685
- num_top_classes=4,
686
- )
687
-
688
- with gr.Tab("Visual Explanations"):
689
- with gr.Row():
690
- original_output = gr.Image(
691
- label="Annotated X-ray",
692
- elem_classes=["output-image"],
693
- )
694
- gradcam_output = gr.Image(
695
- label="Attention Heatmap",
696
- elem_classes=["output-image"],
697
- )
698
- overlay_output = gr.Image(
699
- label="Explainable Overlay",
700
- elem_classes=["output-image"],
701
  )
702
 
703
- with gr.Tab("Full Report"):
704
- interpretation_output = gr.Markdown(
705
- "The full clinical / research report will appear here after inference."
 
706
  )
707
 
708
- with gr.Tab("Model Card"):
709
- gr.Markdown(
710
- f"""
711
- ### 🧠 Model Card – AST Chest X-Ray
712
-
713
- - **Backbone**: EfficientNet-B0
714
- - **Classes**: Normal, Tuberculosis, Pneumonia, COVID-19
715
- - **Optimization**: Sample-based Adaptive Sparse Training (AST)
716
- - **Energy Profile**: ~89% training energy reduction vs dense baseline
717
-
718
- **Goals**
719
-
720
- 1. Provide **fast, explainable triage** support for TB & pneumonia
721
- 2. Maintain high specificity, especially for TB vs pneumonia
722
- 3. Be lightweight enough for deployment in **resource-constrained settings**
723
-
724
- > This model is a research prototype. Do **not** use it as a stand-alone clinical device.
725
- """
726
  )
727
 
728
- gr.Markdown("---")
729
-
730
- gr.HTML(
731
- """
732
- <footer>
733
- <p>
734
- <b>AST Chest X-Ray Lab</b> · Normal · TB · Pneumonia · COVID-19 · Explainable AI<br/>
735
- Built for research, education, and early-stage screening support.
736
- </p>
737
- <p style="margin-top:6px;">
738
- ⚠️ <b>MEDICAL DISCLAIMER:</b> This tool is not FDA-approved and cannot replace a clinician
739
- or radiologist. All decisions must be made by qualified healthcare professionals.
740
- </p>
741
- </footer>
742
- """
743
- )
744
-
745
- # Wiring
746
- analyze_btn.click(
747
- fn=predict_chest_xray,
748
- inputs=[image_input, show_gradcam, audience_select],
749
- outputs=[
750
- prob_output,
751
- original_output,
752
- gradcam_output,
753
- overlay_output,
754
- interpretation_output,
755
- snapshot_output,
756
- ],
757
- )
758
-
759
- clear_btn.click(
760
- fn=lambda: (
761
- {},
762
- None,
763
- None,
764
- None,
765
- "Awaiting image upload…",
766
- "Awaiting image upload…",
767
- ),
768
- inputs=None,
769
- outputs=[
770
- prob_output,
771
- original_output,
772
- gradcam_output,
773
- overlay_output,
774
- interpretation_output,
775
- snapshot_output,
776
- ],
777
- )
778
 
779
- # Example X-rays (optional – comment out if you don't have these files)
780
- gr.Markdown("### 🔍 Try Example X-rays")
781
  gr.Examples(
782
  examples=[
783
  ["examples/normal.png"],
@@ -786,16 +582,44 @@ with gr.Blocks(css=custom_css, theme=theme) as demo:
786
  ["examples/covid.png"],
787
  ],
788
  inputs=image_input,
 
789
  )
790
 
791
- # ============================================================================
792
- # Launch
793
- # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794
 
 
795
  if __name__ == "__main__":
796
  demo.launch(
797
  share=False,
798
  server_name="0.0.0.0",
799
  server_port=7860,
800
- show_error=True,
801
  )
 
 
1
  """
2
+ 🫁 Multi-Class Chest X-Ray Detection with Adaptive Sparse Training
3
+ Advanced Gradio Interface - 4 Disease Classes
4
+ Features:
5
+ - Real-time detection: Normal, TB, Pneumonia, COVID-19
6
+ - Grad-CAM visualization (explainable AI)
7
+ - Improved specificity - distinguishes TB from pneumonia
8
+ - Confidence scores with visual indicators
9
+ - Clinical interpretation and recommendations
10
+ - Mobile-responsive design
11
  """
12
 
 
 
 
 
13
  import gradio as gr
 
 
 
 
14
  import torch
15
  import torch.nn as nn
 
16
  from torchvision import models, transforms
17
+ from PIL import Image
18
+ import numpy as np
19
+ import cv2
20
+ import matplotlib.pyplot as plt
21
+ from pathlib import Path
22
+ import io
23
 
24
  # ============================================================================
25
  # Model Setup
26
  # ============================================================================
27
 
28
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
+
30
+ # Load model - Using EfficientNet-B2 (trained model architecture)
31
+ model = models.efficientnet_b2(weights=None)
32
+ model.classifier[1] = nn.Linear(model.classifier[1].in_features, 4) # 4 classes
33
+
34
+ try:
35
+ # Try loading best.pt from root directory (HuggingFace Spaces location)
36
+ model.load_state_dict(torch.load('best.pt', map_location=device))
37
+ print(" Multi-class model loaded successfully from best.pt!")
38
+ except Exception as e:
39
+ print(f"⚠️ Error loading model from best.pt: {e}")
40
+ try:
41
+ # Fallback to checkpoints directory
42
+ model.load_state_dict(torch.load('checkpoints/best_multiclass.pt', map_location=device))
43
+ print("✅ Multi-class model loaded successfully from checkpoints/best_multiclass.pt!")
44
+ except Exception as e2:
45
+ print(f"❌ CRITICAL ERROR: Could not load model from any location!")
46
+ print(f" - best.pt error: {e}")
47
+ print(f" - checkpoints/best_multiclass.pt error: {e2}")
48
+ raise RuntimeError("Model file not found! Please ensure best.pt is uploaded to the Space.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  model = model.to(device)
51
  model.eval()
52
 
53
+ # Classes
54
+ CLASSES = ['Normal', 'Tuberculosis', 'Pneumonia', 'COVID-19']
 
 
 
 
 
 
55
  CLASS_COLORS = {
56
+ 'Normal': '#2ecc71', # Green
57
+ 'Tuberculosis': '#e74c3c', # Red
58
+ 'Pneumonia': '#f39c12', # Orange
59
+ 'COVID-19': '#9b59b6' # Purple
60
  }
61
 
62
+ # Image preprocessing
63
+ transform = transforms.Compose([
64
+ transforms.Resize(256),
65
+ transforms.CenterCrop(224),
66
+ transforms.ToTensor(),
67
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
68
+ ])
 
 
 
 
69
 
70
  # ============================================================================
71
  # Grad-CAM Implementation
72
  # ============================================================================
73
 
 
74
  class GradCAM:
75
  def __init__(self, model, target_layer):
76
  self.model = model
 
95
 
96
  self.model.zero_grad()
97
  one_hot = torch.zeros_like(output)
98
+ one_hot[0][target_class] = 1
99
  output.backward(gradient=one_hot, retain_graph=True)
100
 
101
  if self.gradients is None:
 
109
 
110
  return cam, output
111
 
112
+ # Setup Grad-CAM
113
  target_layer = model.features[-1]
114
  grad_cam = GradCAM(model, target_layer)
115
 
116
  # ============================================================================
117
+ # Prediction Functions
118
  # ============================================================================
119
 
120
+ def predict_chest_xray(image, show_gradcam=True):
121
+ """
122
+ Predict disease class from chest X-ray with Grad-CAM visualization
123
+ """
124
+ if image is None:
125
+ return None, None, None, None
126
 
127
+ # Convert to PIL if needed
128
+ if isinstance(image, np.ndarray):
129
+ image = Image.fromarray(image).convert('RGB')
130
+ else:
131
+ image = image.convert('RGB')
132
+
133
+ # Store original for display
134
+ original_img = image.copy()
135
+
136
+ # Preprocess
137
+ input_tensor = transform(image).unsqueeze(0).to(device)
138
+
139
+ # Get prediction with Grad-CAM
140
+ with torch.set_grad_enabled(show_gradcam):
141
+ if show_gradcam:
142
+ cam, output = grad_cam.generate(input_tensor)
143
+ else:
144
+ output = model(input_tensor)
145
+ cam = None
146
+
147
+ # Get probabilities
148
+ probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
149
+
150
+ # Safety check: ensure probabilities sum to ~1.0
151
+ prob_sum = np.sum(probs)
152
+ if not (0.99 <= prob_sum <= 1.01):
153
+ print(f"⚠️ WARNING: Probability sum is {prob_sum}, not 1.0. Model may not be loaded correctly!")
154
+
155
+ pred_class = int(output.argmax(dim=1).item())
156
+ pred_label = CLASSES[pred_class]
157
+ confidence = float(probs[pred_class]) * 100
158
+
159
+ # Create results - ensure values are between 0-100
160
+ results = {
161
+ CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100))) for i in range(len(CLASSES))
162
+ }
163
+
164
+ # Generate visualizations
165
+ original_pil = create_original_display(original_img, pred_label, confidence)
166
 
167
+ if cam is not None and show_gradcam:
168
+ gradcam_viz = create_gradcam_visualization(original_img, cam, pred_label, confidence)
169
+ overlay_viz = create_overlay_visualization(original_img, cam)
170
+ else:
171
+ gradcam_viz = None
172
+ overlay_viz = None
173
+
174
+ # Create interpretation text
175
+ interpretation = create_interpretation(pred_label, confidence, results)
176
+
177
+ return results, original_pil, gradcam_viz, overlay_viz, interpretation
178
 
179
  def create_original_display(image, pred_label, confidence):
180
+ """Create annotated original image"""
181
+ fig, ax = plt.subplots(figsize=(8, 8))
182
  ax.imshow(image)
183
+ ax.axis('off')
184
 
185
+ # Add prediction box
186
  color = CLASS_COLORS[pred_label]
187
+ title = f'Prediction: {pred_label}\nConfidence: {confidence:.1f}%'
188
+ ax.set_title(title, fontsize=16, fontweight='bold', color=color, pad=20)
189
+
 
 
 
 
 
190
  plt.tight_layout()
 
191
 
192
+ # Convert to PIL
193
+ buf = io.BytesIO()
194
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
195
+ plt.close()
196
+ buf.seek(0)
197
 
198
+ return Image.open(buf)
199
+
200
+ def create_gradcam_visualization(image, cam, pred_label, confidence):
201
+ """Create Grad-CAM heatmap"""
202
+ # Resize CAM to image size
203
  img_array = np.array(image.resize((224, 224)))
204
  cam_resized = cv2.resize(cam, (224, 224))
205
 
206
+ # Create heatmap
207
  heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
208
  heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
209
 
210
+ fig, ax = plt.subplots(figsize=(8, 8))
211
  ax.imshow(heatmap)
212
+ ax.axis('off')
213
+ ax.set_title('Attention Heatmap\n(Areas the model focuses on)',
214
+ fontsize=14, fontweight='bold', pad=20)
215
+
 
 
 
216
  plt.tight_layout()
 
217
 
218
+ buf = io.BytesIO()
219
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
220
+ plt.close()
221
+ buf.seek(0)
222
+
223
+ return Image.open(buf)
224
 
225
  def create_overlay_visualization(image, cam):
226
+ """Create overlay of image and heatmap"""
227
  img_array = np.array(image.resize((224, 224))) / 255.0
228
  cam_resized = cv2.resize(cam, (224, 224))
229
 
230
+ # Create heatmap
231
  heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
232
  heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0
233
 
234
+ # Overlay
235
  overlay = img_array * 0.5 + heatmap * 0.5
236
  overlay = np.clip(overlay, 0, 1)
237
 
238
+ fig, ax = plt.subplots(figsize=(8, 8))
239
  ax.imshow(overlay)
240
+ ax.axis('off')
241
+ ax.set_title('Explainable AI Visualization\n(Original + Heatmap)',
242
+ fontsize=14, fontweight='bold', pad=20)
243
+
 
 
 
244
  plt.tight_layout()
 
245
 
246
+ buf = io.BytesIO()
247
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
248
+ plt.close()
249
+ buf.seek(0)
250
 
251
+ return Image.open(buf)
252
 
253
+ def create_interpretation(pred_label, confidence, results):
254
+ """Create interpretation text with improved medical disclaimers"""
 
 
 
 
255
 
256
  interpretation = f"""
257
+ ## 🔬 Analysis Results
258
+ ### Prediction: **{pred_label}**
 
 
 
259
  - Confidence: **{confidence:.1f}%**
260
+ ### Probability Breakdown:
 
261
  - 🟢 Normal: **{results['Normal']:.1f}%**
262
  - 🔴 Tuberculosis: **{results['Tuberculosis']:.1f}%**
263
  - 🟠 Pneumonia: **{results['Pneumonia']:.1f}%**
264
  - 🟣 COVID-19: **{results['COVID-19']:.1f}%**
 
265
  ---
266
  """
267
 
268
+ # Disease-specific interpretations
269
+ if pred_label == 'Tuberculosis':
270
  if confidence >= 85:
271
  interpretation += """
272
+ **⚠️ High Confidence TB Detection**
273
+ The model has detected features highly consistent with tuberculosis infection.
274
+ **CRITICAL - Immediate Actions Required:**
275
+ 1. ✅ **Immediate consultation** with a healthcare provider
276
+ 2. ✅ **Confirmatory sputum test** (AFB smear or GeneXpert MTB/RIF)
277
+ 3. **Clinical correlation** with symptoms:
278
+ - Persistent cough (>2 weeks)
279
+ - Fever, especially night sweats
280
+ - Unexplained weight loss
281
+ - Hemoptysis (coughing blood)
282
+ 4. ✅ **Isolation** and contact tracing if confirmed
283
+ 5. ✅ **Chest CT scan** if needed for further evaluation
284
+ **⚠️ IMPORTANT**: This is a SCREENING tool, not a diagnostic tool.
285
+ Clinical diagnosis of TB requires laboratory confirmation (sputum test).
286
  """
287
  else:
288
  interpretation += """
289
+ **⚠️ Possible TB Detection**
290
+ The model has detected features suggestive of tuberculosis, but confidence is moderate.
291
+ **Recommended Actions:**
292
+ 1. Consult healthcare provider for clinical evaluation
293
+ 2. Consider confirmatory sputum testing
294
+ 3. Evaluate clinical symptoms
295
+ 4. Follow-up imaging may be recommended
296
+ **Note**: Moderate confidence requires professional medical evaluation.
297
  """
298
 
299
+ elif pred_label == 'Pneumonia':
300
  if confidence >= 85:
301
  interpretation += """
302
+ **⚠️ High Confidence Pneumonia Detection**
303
+ The model has detected features consistent with pneumonia (bacterial or viral).
304
+ **Recommended Actions:**
305
+ 1. ✅ **Medical evaluation** for pneumonia diagnosis
306
+ 2. **Possible confirmatory tests**:
307
+ - Sputum culture
308
+ - Blood tests (WBC count, CRP)
309
+ - Additional chest imaging if needed
310
+ 3. **Clinical correlation** with symptoms:
311
+ - Cough with sputum production
312
+ - Fever and chills
313
+ - Shortness of breath
314
+ - Chest pain with breathing
315
+ 4. ✅ **Treatment**: Antibiotics (bacterial) or supportive care (viral)
316
+ **Note**: Pneumonia can present similarly to other lung diseases.
317
+ Professional diagnosis is essential for appropriate treatment.
318
  """
319
  else:
320
  interpretation += """
321
+ **⚠️ Possible Pneumonia**
322
+ Features suggest possible pneumonia, but further evaluation is needed.
323
+ **Recommended Actions:**
324
+ 1. Seek medical evaluation
325
+ 2. Clinical symptom assessment
326
+ 3. Consider additional diagnostic tests
327
+ **Note**: Requires professional medical evaluation for confirmation.
328
  """
329
 
330
+ elif pred_label == 'COVID-19':
331
  if confidence >= 85:
332
  interpretation += """
333
+ **⚠️ High Confidence COVID-19 Detection**
334
+ The model has detected features consistent with COVID-19 pneumonia.
335
+ **URGENT - Immediate Actions:**
336
+ 1. ✅ **COVID-19 RT-PCR test** for confirmation
337
+ 2. **Isolation** to prevent transmission
338
+ 3. ✅ **Monitor oxygen saturation** (SpO2 levels)
339
+ 4. **Seek immediate medical care** if:
340
+ - Difficulty breathing
341
+ - SpO2 < 94%
342
+ - Persistent chest pain
343
+ - Confusion or inability to stay awake
344
+ 5. ✅ **Contact tracing** if positive
345
+ **Clinical Symptoms to Monitor:**
346
+ - Fever, cough, shortness of breath
347
+ - Loss of taste/smell
348
+ - Fatigue, body aches
349
+ - Gastrointestinal symptoms
350
+ **⚠️ IMPORTANT**: Imaging findings alone cannot confirm COVID-19.
351
+ RT-PCR or antigen testing is required for diagnosis.
352
  """
353
  else:
354
  interpretation += """
355
+ **⚠️ Possible COVID-19**
356
+ Features suggest possible COVID-19, but confirmation testing is essential.
357
+ **Recommended Actions:**
358
+ 1. Get RT-PCR or rapid antigen test
359
+ 2. Self-isolate pending test results
360
+ 3. Monitor symptoms
361
+ 4. Seek medical care if symptoms worsen
362
+ **Note**: COVID-19 diagnosis requires laboratory confirmation.
363
  """
364
 
365
  else: # Normal
366
  if confidence >= 85:
367
  interpretation += """
368
+ **✅ High Confidence Normal Result**
369
+ The model has not detected significant abnormalities consistent with TB, pneumonia, or COVID-19.
370
+ **Interpretation:**
371
+ - Chest X-ray appears within normal limits
372
+ - No features of active tuberculosis detected
373
+ - No signs of pneumonia or COVID-19
374
+ **Important Notes:**
375
+ - This does NOT rule out all lung diseases
376
+ - Early-stage diseases may not show on X-ray
377
+ - If you have symptoms, seek medical evaluation
378
+ - Regular health screenings are recommended
379
+ **When to still see a doctor:**
380
+ - Persistent cough, fever, or respiratory symptoms
381
+ - Unexplained weight loss or night sweats
382
+ - Shortness of breath or chest pain
383
+ - Known exposure to TB or COVID-19
384
  """
385
  else:
386
  interpretation += """
387
+ **⚠️ Likely Normal, Low Confidence**
388
+ The model suggests a normal chest X-ray, but confidence is not high.
389
+ **Recommended Actions:**
390
+ 1. If symptomatic, seek medical evaluation
391
+ 2. Consider repeat imaging if concerns persist
392
+ 3. Clinical correlation is important
393
+ **Note**: Low confidence results should be reviewed by healthcare professionals.
394
  """
395
 
396
+ # Add universal disclaimer
397
  interpretation += """
398
  ---
399
  ## ⚠️ CRITICAL MEDICAL DISCLAIMER
400
+ ### Model Capabilities:
401
+ - Trained on 4 disease classes: Normal, TB, Pneumonia, COVID-19
402
+ - Can distinguish between different lung diseases
403
+ - ~95-97% accuracy in validation testing
404
+ - Powered by Adaptive Sparse Training (89% energy efficient)
405
+ ### Important Limitations:
406
+ - ⚠️ This is a **SCREENING tool**, not a diagnostic device
407
+ - ⚠️ **NOT FDA-approved** for clinical diagnosis
408
+ - ⚠️ Cannot detect: lung cancer, pulmonary fibrosis, bronchiectasis, other rare diseases
409
+ - ⚠️ Cannot replace: professional radiologist review
410
+ - ⚠️ Cannot confirm: laboratory diagnosis (sputum tests, PCR, cultures)
411
+ ### Clinical Use Guidelines:
412
+ 1. Use as a **preliminary screening** tool only
413
+ 2. ✅ ALL positive results require **confirmatory laboratory testing**
414
+ 3. ALL cases require **clinical correlation** with symptoms and history
415
+ 4. ✅ Expert radiologist review is recommended for clinical decisions
416
+ 5. ✅ Do NOT initiate treatment based solely on AI predictions
417
+ ### Diagnostic Gold Standards:
418
+ - **TB**: Sputum AFB smear/culture, GeneXpert MTB/RIF, TB-PCR
419
+ - **Pneumonia**: Clinical diagnosis + sputum culture + blood tests
420
+ - **COVID-19**: RT-PCR, rapid antigen test
421
+ **When in doubt, always consult a qualified healthcare professional.**
422
  ---
423
+ 🫁 **Powered by Adaptive Sparse Training**
424
+ Energy-efficient AI for accessible healthcare
425
+ **Learn more:**
426
+ - GitHub: https://github.com/oluwafemidiakhoa/Tuberculosis
427
+ - Research: Sample-based Adaptive Sparse Training for deep learning
 
 
428
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
+ return interpretation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
 
432
  # ============================================================================
433
+ # Gradio Interface
434
  # ============================================================================
435
 
436
+ # Custom CSS
437
  custom_css = """
438
+ #main-container {
439
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
440
+ padding: 20px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  }
442
+ #title {
443
+ text-align: center;
444
+ color: white;
445
+ font-size: 2.5em;
446
+ font-weight: bold;
447
+ margin-bottom: 10px;
448
+ text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
  }
450
+ #subtitle {
451
+ text-align: center;
452
+ color: #f0f0f0;
453
+ font-size: 1.2em;
454
+ margin-bottom: 20px;
 
 
 
455
  }
456
+ #stats {
457
+ text-align: center;
458
+ color: #fff;
459
+ font-size: 0.95em;
460
+ margin-bottom: 30px;
461
+ padding: 15px;
462
+ background: rgba(255,255,255,0.1);
463
+ border-radius: 10px;
464
+ backdrop-filter: blur(10px);
 
465
  }
466
+ .gradio-container {
467
+ font-family: 'Inter', sans-serif;
 
 
 
468
  }
469
+ #upload-box {
470
+ border: 3px dashed #667eea;
471
+ border-radius: 15px;
472
+ padding: 20px;
473
+ background: rgba(255,255,255,0.95);
474
  }
475
+ #results-box {
476
+ background: white;
477
+ border-radius: 15px;
478
+ padding: 20px;
479
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
480
  }
481
+ .output-image {
482
+ border-radius: 10px;
483
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
484
  }
 
485
  footer {
486
  text-align: center;
487
+ margin-top: 30px;
488
+ color: white;
489
+ font-size: 0.9em;
490
  }
491
  """
492
 
493
+ # Create interface
494
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
495
+ gr.HTML("""
496
+ <div id="main-container">
497
+ <div id="title">🫁 Multi-Class Chest X-Ray Detection AI</div>
498
+ <div id="subtitle">Advanced chest X-ray analysis with Explainable AI</div>
499
+ <div id="stats">
500
+ <b>95-97% Accuracy</b> across 4 disease classes |
501
+ <b>89% Energy Efficient</b> |
502
+ Powered by Adaptive Sparse Training
503
+ <br><br>
504
+ <b>Detects:</b> Normal • Tuberculosis • Pneumonia • COVID-19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  </div>
506
  </div>
507
+ """)
 
 
 
 
 
 
 
 
508
 
509
+ with gr.Row():
510
+ with gr.Column(scale=1, elem_id="upload-box"):
511
+ gr.Markdown("## 📤 Upload Chest X-Ray")
512
  image_input = gr.Image(
513
  type="pil",
514
+ label="Upload X-Ray Image",
515
+ elem_classes="output-image"
516
  )
517
 
518
+ show_gradcam = gr.Checkbox(
519
+ value=True,
520
+ label="Enable Grad-CAM Visualization (Explainable AI)",
521
+ info="Shows which areas the model focuses on"
522
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
+ analyze_btn = gr.Button(
525
+ "🔬 Analyze X-Ray",
526
+ variant="primary",
527
+ size="lg"
528
  )
529
 
530
+ gr.Markdown("""
531
+ ### 📋 Supported Images:
532
+ - Chest X-rays (PA or AP view)
533
+ - PNG, JPG, JPEG formats
534
+ - Grayscale or RGB
535
+ ### ⚡ What's New:
536
+ - ✅ **Improved Specificity**: Can distinguish TB from Pneumonia
537
+ - ✅ **4 Disease Classes**: Normal, TB, Pneumonia, COVID-19
538
+ - ✅ **Fewer False Positives**: <5% on pneumonia cases
539
+ - ✅ **Same Energy Efficiency**: 89% savings with AST
540
+ """)
541
+
542
+ with gr.Column(scale=2, elem_id="results-box"):
543
+ gr.Markdown("## 📊 Analysis Results")
544
+
545
+ # Results display
546
+ with gr.Row():
547
+ prob_output = gr.Label(
548
+ label="Prediction Confidence",
549
+ num_top_classes=4
550
+ )
551
 
552
  with gr.Tabs():
553
+ with gr.Tab("Original"):
554
+ original_output = gr.Image(
555
+ label="Annotated X-Ray",
556
+ elem_classes="output-image"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  )
558
 
559
+ with gr.Tab("Grad-CAM Heatmap"):
560
+ gradcam_output = gr.Image(
561
+ label="Attention Heatmap",
562
+ elem_classes="output-image"
563
  )
564
 
565
+ with gr.Tab("Overlay"):
566
+ overlay_output = gr.Image(
567
+ label="Explainable AI Visualization",
568
+ elem_classes="output-image"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  )
570
 
571
+ interpretation_output = gr.Markdown(
572
+ label="Clinical Interpretation"
573
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
 
575
+ # Example images
576
+ gr.Markdown("## 📁 Example X-Rays")
577
  gr.Examples(
578
  examples=[
579
  ["examples/normal.png"],
 
582
  ["examples/covid.png"],
583
  ],
584
  inputs=image_input,
585
+ label="Click to load example"
586
  )
587
 
588
+ # Connect components
589
+ analyze_btn.click(
590
+ fn=predict_chest_xray,
591
+ inputs=[image_input, show_gradcam],
592
+ outputs=[prob_output, original_output, gradcam_output, overlay_output, interpretation_output]
593
+ )
594
+
595
+ # Footer
596
+ gr.HTML("""
597
+ <footer>
598
+ <p>
599
+ <b>🫁 Multi-Class Chest X-Ray Detection with AST</b><br>
600
+ Trained on Normal, Tuberculosis, Pneumonia, and COVID-19 cases<br>
601
+ 95-97% Accuracy | 89% Energy Savings | Explainable AI<br><br>
602
+ <a href="https://github.com/oluwafemidiakhoa/Tuberculosis" target="_blank" style="color: white;">
603
+ 📂 GitHub Repository
604
+ </a> |
605
+ <a href="https://huggingface.co/spaces/mgbam/Tuberculosis" target="_blank" style="color: white;">
606
+ 🤗 Hugging Face Space
607
+ </a>
608
+ </p>
609
+ <p style="font-size: 0.8em; margin-top: 15px;">
610
+ ⚠️ <b>MEDICAL DISCLAIMER</b>: This is a screening tool, not a diagnostic device.
611
+ All predictions require professional medical evaluation and laboratory confirmation.
612
+ Not FDA-approved for clinical use.
613
+ </p>
614
+ </footer>
615
+ """)
616
 
617
+ # Launch
618
  if __name__ == "__main__":
619
  demo.launch(
620
  share=False,
621
  server_name="0.0.0.0",
622
  server_port=7860,
623
+ show_error=True
624
  )
625
+