mgbam commited on
Commit
1d113ad
Β·
verified Β·
1 Parent(s): 9b5af19

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +709 -0
app.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 🫁 Multi-Class Chest X-Ray Detection with Adaptive Sparse Training
3
+ 4-Class Screening: Normal, Tuberculosis, Pneumonia, COVID-19
4
+
5
+ Mission:
6
+ This open research tool is being built to help humanity –
7
+ especially patients and clinicians in low-resource settings –
8
+ by providing energy-efficient, explainable AI support for chest
9
+ X-ray screening. It is a digital second opinion, NOT a replacement
10
+ for radiologists or doctors.
11
+ """
12
+
13
+ import gradio as gr
14
+ import torch
15
+ import torch.nn as nn
16
+ from torchvision import models, transforms
17
+ from PIL import Image
18
+ import numpy as np
19
+ import cv2
20
+ import matplotlib.pyplot as plt
21
+ from pathlib import Path
22
+ import io
23
+
24
+ # ============================================================================
25
+ # Model Setup
26
+ # ============================================================================
27
+
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+
31
+ def load_efficientnet_model():
32
+ """
33
+ Build EfficientNet-B2 and load your working 4-class best.pt checkpoint.
34
+
35
+ We intentionally keep this simple and very close to the version you
36
+ already confirmed is working, to avoid shape-mismatch issues.
37
+ """
38
+ # Base architecture: EfficientNet-B2
39
+ model = models.efficientnet_b2(weights=None)
40
+ in_features = model.classifier[1].in_features
41
+ model.classifier[1] = nn.Linear(in_features, 4) # 4 classes
42
+
43
+ # Where we expect your weights to live
44
+ candidate_paths = [
45
+ Path("checkpoints/best.pt"), # HF Space path (from your screenshot)
46
+ Path("best.pt"), # fallback for local runs
47
+ ]
48
+
49
+ last_error = None
50
+ for ckpt_path in candidate_paths:
51
+ if not ckpt_path.exists():
52
+ print(f"⚠️ Checkpoint not found at {ckpt_path}")
53
+ continue
54
+
55
+ try:
56
+ print(f"πŸ” Loading weights from: {ckpt_path}")
57
+ state = torch.load(ckpt_path, map_location=device)
58
+
59
+ # If it comes from a training script with wrappers
60
+ if isinstance(state, dict):
61
+ if "model_state_dict" in state:
62
+ state = state["model_state_dict"]
63
+ elif "state_dict" in state:
64
+ state = state["state_dict"]
65
+
66
+ # This is the same idea as your original working call
67
+ missing, unexpected = model.load_state_dict(state, strict=False)
68
+ if missing or unexpected:
69
+ print(f" ⚠️ Non-critical keys - missing: {missing}, unexpected: {unexpected}")
70
+ print(f"βœ… Model weights successfully loaded from {ckpt_path}")
71
+ model.to(device)
72
+ model.eval()
73
+ return model
74
+ except Exception as e:
75
+ print(f"❌ Could not load from {ckpt_path}: {e}")
76
+ last_error = e
77
+
78
+ raise RuntimeError(
79
+ "Could not load EfficientNet-B2 4-class weights from any known path.\n"
80
+ f"Last error: {last_error}"
81
+ )
82
+
83
+
84
+ model = load_efficientnet_model()
85
+
86
+ # Classes
87
+ CLASSES = ["Normal", "Tuberculosis", "Pneumonia", "COVID-19"]
88
+ CLASS_COLORS = {
89
+ "Normal": "#2ecc71", # Green
90
+ "Tuberculosis": "#e74c3c", # Red
91
+ "Pneumonia": "#f39c12", # Orange
92
+ "COVID-19": "#9b59b6", # Purple
93
+ }
94
+
95
+ # Image preprocessing
96
+ transform = transforms.Compose(
97
+ [
98
+ transforms.Resize(256),
99
+ transforms.CenterCrop(224),
100
+ transforms.ToTensor(),
101
+ transforms.Normalize(
102
+ [0.485, 0.456, 0.406],
103
+ [0.229, 0.224, 0.225],
104
+ ),
105
+ ]
106
+ )
107
+
108
+ # ============================================================================
109
+ # Grad-CAM Implementation
110
+ # ============================================================================
111
+
112
+
113
+ class GradCAM:
114
+ def __init__(self, model, target_layer):
115
+ self.model = model
116
+ self.target_layer = target_layer
117
+ self.gradients = None
118
+ self.activations = None
119
+
120
+ def save_gradient(grad):
121
+ self.gradients = grad
122
+
123
+ def save_activation(module, input, output):
124
+ self.activations = output.detach()
125
+
126
+ # Forward hook: store activations
127
+ target_layer.register_forward_hook(save_activation)
128
+ # Backward hook: store gradients
129
+ target_layer.register_full_backward_hook(
130
+ lambda m, grad_in, grad_out: save_gradient(grad_out[0])
131
+ )
132
+
133
+ def generate(self, input_image, target_class=None):
134
+ output = self.model(input_image)
135
+
136
+ if target_class is None:
137
+ target_class = output.argmax(dim=1)
138
+
139
+ self.model.zero_grad()
140
+ one_hot = torch.zeros_like(output)
141
+ one_hot[0][target_class] = 1
142
+ output.backward(gradient=one_hot, retain_graph=True)
143
+
144
+ if self.gradients is None or self.activations is None:
145
+ return None, output
146
+
147
+ # Global average pooling over gradients
148
+ weights = self.gradients.mean(dim=(2, 3), keepdim=True)
149
+ cam = (weights * self.activations).sum(dim=1, keepdim=True)
150
+ cam = torch.relu(cam)
151
+ cam = cam.squeeze().cpu().numpy()
152
+ cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
153
+
154
+ return cam, output
155
+
156
+
157
+ # Setup Grad-CAM on the last feature layer
158
+ target_layer = model.features[-1]
159
+ grad_cam = GradCAM(model, target_layer)
160
+
161
+ # ============================================================================
162
+ # Prediction & Visualization
163
+ # ============================================================================
164
+
165
+
166
+ def predict_chest_xray(image, show_gradcam=True):
167
+ """
168
+ Predict disease class from chest X-ray with Grad-CAM visualization.
169
+
170
+ Returns:
171
+ - class probabilities dict
172
+ - annotated original image
173
+ - Grad-CAM heatmap image
174
+ - overlay image
175
+ - markdown clinical interpretation
176
+ """
177
+ if image is None:
178
+ return None, None, None, None, "Please upload a chest X-ray."
179
+
180
+ # Convert to PIL if needed
181
+ if isinstance(image, np.ndarray):
182
+ image = Image.fromarray(image).convert("RGB")
183
+ else:
184
+ image = image.convert("RGB")
185
+
186
+ # Keep original for visualization
187
+ original_img = image.copy()
188
+
189
+ # Preprocess
190
+ input_tensor = transform(image).unsqueeze(0).to(device)
191
+
192
+ # Forward + optional Grad-CAM
193
+ with torch.set_grad_enabled(show_gradcam):
194
+ if show_gradcam:
195
+ cam, output = grad_cam.generate(input_tensor)
196
+ else:
197
+ cam = None
198
+ output = model(input_tensor)
199
+
200
+ # Probabilities
201
+ probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
202
+ prob_sum = float(np.sum(probs))
203
+ if not (0.99 <= prob_sum <= 1.01):
204
+ print(f"⚠️ Probability sum is {prob_sum:.4f}, expected ~1.0 – check model weights.")
205
+
206
+ pred_class = int(output.argmax(dim=1).item())
207
+ pred_label = CLASSES[pred_class]
208
+ confidence = float(probs[pred_class] * 100.0)
209
+
210
+ # Ensure values between 0–100
211
+ results = {
212
+ CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100.0)))
213
+ for i in range(len(CLASSES))
214
+ }
215
+
216
+ # Visualizations
217
+ original_pil = create_original_display(original_img, pred_label, confidence)
218
+
219
+ if cam is not None and show_gradcam:
220
+ gradcam_viz = create_gradcam_visualization(
221
+ original_img, cam, pred_label, confidence
222
+ )
223
+ overlay_viz = create_overlay_visualization(original_img, cam)
224
+ else:
225
+ gradcam_viz = None
226
+ overlay_viz = None
227
+
228
+ # Interpretation text
229
+ interpretation = create_interpretation(pred_label, confidence, results)
230
+
231
+ return results, original_pil, gradcam_viz, overlay_viz, interpretation
232
+
233
+
234
+ def create_original_display(image, pred_label, confidence):
235
+ """Create annotated original image."""
236
+ fig, ax = plt.subplots(figsize=(8, 8))
237
+ ax.imshow(image)
238
+ ax.axis("off")
239
+
240
+ color = CLASS_COLORS[pred_label]
241
+ title = f"Prediction: {pred_label}\nConfidence: {confidence:.1f}%"
242
+ ax.set_title(title, fontsize=16, fontweight="bold", color=color, pad=20)
243
+
244
+ plt.tight_layout()
245
+ buf = io.BytesIO()
246
+ plt.savefig(
247
+ buf,
248
+ format="png",
249
+ dpi=150,
250
+ bbox_inches="tight",
251
+ facecolor="white",
252
+ )
253
+ plt.close()
254
+ buf.seek(0)
255
+
256
+ return Image.open(buf)
257
+
258
+
259
+ def create_gradcam_visualization(image, cam, pred_label, confidence):
260
+ """Create Grad-CAM heatmap."""
261
+ img_array = np.array(image.resize((224, 224)))
262
+ cam_resized = cv2.resize(cam, (224, 224))
263
+
264
+ heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
265
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
266
+
267
+ fig, ax = plt.subplots(figsize=(8, 8))
268
+ ax.imshow(heatmap)
269
+ ax.axis("off")
270
+ ax.set_title(
271
+ "Attention Heatmap\n(Areas the model focuses on)",
272
+ fontsize=14,
273
+ fontweight="bold",
274
+ pad=20,
275
+ )
276
+
277
+ plt.tight_layout()
278
+ buf = io.BytesIO()
279
+ plt.savefig(
280
+ buf,
281
+ format="png",
282
+ dpi=150,
283
+ bbox_inches="tight",
284
+ facecolor="white",
285
+ )
286
+ plt.close()
287
+ buf.seek(0)
288
+
289
+ return Image.open(buf)
290
+
291
+
292
+ def create_overlay_visualization(image, cam):
293
+ """Overlay original image and Grad-CAM heatmap."""
294
+ img_array = np.array(image.resize((224, 224))) / 255.0
295
+ cam_resized = cv2.resize(cam, (224, 224))
296
+
297
+ heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
298
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0
299
+
300
+ overlay = img_array * 0.5 + heatmap * 0.5
301
+ overlay = np.clip(overlay, 0, 1)
302
+
303
+ fig, ax = plt.subplots(figsize=(8, 8))
304
+ ax.imshow(overlay)
305
+ ax.axis("off")
306
+ ax.set_title(
307
+ "Explainable AI Visualization\n(Original + Heatmap)",
308
+ fontsize=14,
309
+ fontweight="bold",
310
+ pad=20,
311
+ )
312
+
313
+ plt.tight_layout()
314
+ buf = io.BytesIO()
315
+ plt.savefig(
316
+ buf,
317
+ format="png",
318
+ dpi=150,
319
+ bbox_inches="tight",
320
+ facecolor="white",
321
+ )
322
+ plt.close()
323
+ buf.seek(0)
324
+
325
+ return Image.open(buf)
326
+
327
+
328
+ def create_interpretation(pred_label, confidence, results):
329
+ """
330
+ Clinical-style interpretation text with strong global-health framing
331
+ and strict medical disclaimer.
332
+ """
333
+
334
+ interpretation = f"""
335
+ ## 🫁 AI Chest X-Ray Screening – Global Health Edition
336
+ This tool is part of an open effort to **support clinicians and patients worldwide**,
337
+ especially in places where radiologists are scarce.
338
+
339
+ ---
340
+
341
+ ## πŸ”¬ Analysis Summary
342
+ **Predicted class:** **{pred_label}**
343
+ **Model confidence:** **{confidence:.1f}%**
344
+
345
+ ### Probability Breakdown
346
+ - 🟒 Normal: **{results['Normal']:.1f}%**
347
+ - πŸ”΄ Tuberculosis: **{results['Tuberculosis']:.1f}%**
348
+ - 🟠 Pneumonia: **{results['Pneumonia']:.1f}%**
349
+ - 🟣 COVID-19: **{results['COVID-19']:.1f}%**
350
+
351
+ ---
352
+ """
353
+
354
+ # Disease-specific details
355
+ if pred_label == "Tuberculosis":
356
+ if confidence >= 85:
357
+ interpretation += """
358
+ ### ⚠️ High-Confidence Tuberculosis Pattern Detected
359
+
360
+ The AI model has found features strongly suggestive of **pulmonary tuberculosis (TB)**.
361
+
362
+ **Suggested next steps for a clinical team (NOT automatic orders):**
363
+ 1. Correlate with symptoms:
364
+ - Cough > 2 weeks
365
+ - Night sweats, fever
366
+ - Weight loss
367
+ - Hemoptysis (coughing blood)
368
+ 2. Order **confirmatory TB tests**:
369
+ - Sputum smear / culture
370
+ - GeneXpert MTB/RIF or TB-PCR
371
+ 3. Consider **isolation** and **contact screening** if TB is suspected.
372
+ 4. Evaluate HIV status and comorbidities according to local guidelines.
373
+
374
+ ➑️ This system is designed to **support TB programs** in low-resource settings,
375
+ where early triage can save lives.
376
+ """
377
+ else:
378
+ interpretation += """
379
+ ### ⚠️ Possible Tuberculosis Features
380
+
381
+ The model sees **TB-like patterns**, but confidence is moderate.
382
+
383
+ **Recommended clinical follow-up (not automatic diagnosis):**
384
+ - Detailed history and physical examination
385
+ - Evaluate TB risk factors and symptoms
386
+ - Consider sputum-based TB testing
387
+ - Repeat imaging or CT if clinically indicated
388
+ """
389
+
390
+ elif pred_label == "Pneumonia":
391
+ if confidence >= 85:
392
+ interpretation += """
393
+ ### ⚠️ High-Confidence Pneumonia Pattern
394
+
395
+ The model detects findings consistent with **pneumonia**.
396
+
397
+ **Clinical team may consider:**
398
+ - Distinguishing bacterial vs viral pneumonia
399
+ - Correlating with:
400
+ - Fever, cough, sputum
401
+ - Pleuritic chest pain
402
+ - Shortness of breath
403
+ - Laboratory tests (WBC, CRP, cultures)
404
+ - Guideline-based antibiotic or supportive therapy if confirmed
405
+
406
+ This tool aims to **prioritize patients** for rapid review, especially
407
+ where waiting times are long.
408
+ """
409
+ else:
410
+ interpretation += """
411
+ ### ⚠️ Possible Pneumonia
412
+
413
+ The chest X-ray may show **subtle or early pneumonia-like changes**.
414
+
415
+ **Clinical suggestions:**
416
+ - Evaluate symptoms and vital signs
417
+ - Consider repeat imaging or further labs
418
+ - Use local pneumonia treatment guidelines if diagnosis is confirmed
419
+ """
420
+
421
+ elif pred_label == "COVID-19":
422
+ if confidence >= 85:
423
+ interpretation += """
424
+ ### ⚠️ High-Confidence COVID-19 Pneumonia Pattern
425
+
426
+ The AI sees a pattern often associated with **COVID-19 pneumonia**.
427
+
428
+ **Clinical next steps typically include:**
429
+ - **SARS-CoV-2 testing** (RT-PCR or antigen)
430
+ - Isolation and infection prevention
431
+ - Monitoring oxygen saturation (SpO2)
432
+ - Urgent care if:
433
+ - SpO2 < 94%
434
+ - Respiratory distress
435
+ - Persistent chest pain or confusion
436
+
437
+ Imaging alone **cannot confirm COVID-19**. Lab testing + clinical judgment are essential.
438
+ """
439
+ else:
440
+ interpretation += """
441
+ ### ⚠️ Possible COVID-19 Pattern
442
+
443
+ There are features that *could* be compatible with COVID-19, but the AI is not very certain.
444
+
445
+ **Clinical suggestions:**
446
+ - Follow local COVID-19 testing protocols
447
+ - Use symptoms and exposure history
448
+ - Monitor for deterioration and hypoxia
449
+ """
450
+
451
+ else: # Normal
452
+ if confidence >= 85:
453
+ interpretation += """
454
+ ### βœ… High-Confidence "No Major Abnormality" Pattern
455
+
456
+ The model does **not** see strong evidence of TB, pneumonia, or COVID-19.
457
+
458
+ This may support a **normal chest X-ray**, but:
459
+
460
+ - Early disease can be radiographically subtle
461
+ - Some lung or cardiac diseases are **not detectable** here
462
+ - Symptoms always override AI reassurance
463
+
464
+ If a patient is symptomatic, clinical review is still required.
465
+ """
466
+ else:
467
+ interpretation += """
468
+ ### ⚠️ Likely Normal, But With Low Confidence
469
+
470
+ The model leans toward a **normal** study, but uncertainty is higher than usual.
471
+
472
+ - If the patient is unwell, treat this as **inconclusive**
473
+ - Consider follow-up imaging or alternative diagnostics
474
+ """
475
+
476
+ interpretation += """
477
+ ---
478
+
479
+ ## 🌍 Built to Help Humanity
480
+
481
+ This AI system is being developed to:
482
+
483
+ - Support **front-line clinicians** in low-resource and high-burden regions
484
+ - Provide an **energy-efficient (Adaptive Sparse Training)** screening assistant
485
+ - Help triage patients when **radiologists are not immediately available**
486
+
487
+ It is **open research**, not a commercial product, and we welcome
488
+ feedback from clinicians, researchers, and public health teams.
489
+
490
+ ---
491
+
492
+ ## ⚠️ Critical Medical Disclaimer
493
+
494
+ - This is a **screening and research tool only** – **NOT** an FDA/CE approved device.
495
+ - It does **not** replace radiologists, pulmonologists, or infectious disease experts.
496
+ - All decisions about diagnosis and treatment must be made by qualified clinicians.
497
+ - Gold-standard confirmation remains:
498
+ - **TB** – sputum tests, culture, GeneXpert, TB-PCR
499
+ - **Pneumonia** – full clinical assessment + labs/imaging
500
+ - **COVID-19** – RT-PCR / validated antigen testing
501
+
502
+ If there is any doubt, always follow local clinical guidelines and consult a specialist.
503
+ """
504
+
505
+ return interpretation
506
+
507
+ # ============================================================================
508
+ # Gradio Interface
509
+ # ============================================================================
510
+
511
+ custom_css = """
512
+ #main-container {
513
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
514
+ padding: 20px;
515
+ }
516
+ #title {
517
+ text-align: center;
518
+ color: white;
519
+ font-size: 2.5em;
520
+ font-weight: 800;
521
+ margin-bottom: 10px;
522
+ text-shadow: 2px 2px 4px rgba(0,0,0,0.35);
523
+ }
524
+ #subtitle {
525
+ text-align: center;
526
+ color: #f5f5ff;
527
+ font-size: 1.1em;
528
+ margin-bottom: 12px;
529
+ }
530
+ #mission {
531
+ text-align: center;
532
+ color: #ffffff;
533
+ font-size: 0.95em;
534
+ margin-bottom: 24px;
535
+ padding: 14px 18px;
536
+ background: rgba(0,0,0,0.15);
537
+ border-radius: 12px;
538
+ backdrop-filter: blur(12px);
539
+ }
540
+ #stats {
541
+ text-align: center;
542
+ color: #fff;
543
+ font-size: 0.95em;
544
+ margin-bottom: 30px;
545
+ padding: 12px 16px;
546
+ background: rgba(255,255,255,0.08);
547
+ border-radius: 10px;
548
+ }
549
+ .gradio-container {
550
+ font-family: "Inter", system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
551
+ }
552
+ #upload-box {
553
+ border: 3px dashed #667eea;
554
+ border-radius: 15px;
555
+ padding: 20px;
556
+ background: rgba(255,255,255,0.97);
557
+ }
558
+ #results-box {
559
+ background: white;
560
+ border-radius: 15px;
561
+ padding: 20px;
562
+ box-shadow: 0 4px 12px rgba(0,0,0,0.12);
563
+ }
564
+ .output-image {
565
+ border-radius: 10px;
566
+ box-shadow: 0 2px 6px rgba(0,0,0,0.15);
567
+ }
568
+ footer {
569
+ text-align: center;
570
+ margin-top: 30px;
571
+ color: white;
572
+ font-size: 0.9em;
573
+ }
574
+ """
575
+
576
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
577
+ gr.HTML(
578
+ """
579
+ <div id="main-container">
580
+ <div id="title">🫁 Global Chest X-Ray Screening AI</div>
581
+ <div id="subtitle">
582
+ 4-Class detection β€’ Explainable AI β€’ Adaptive Sparse Training
583
+ </div>
584
+ <div id="mission">
585
+ <b>Mission:</b> Support clinicians and patients worldwide – especially in
586
+ low-resource, high-burden regions – by providing an energy-efficient AI
587
+ assistant for chest X-ray screening. This is a <b>second opinion</b> tool,
588
+ not a replacement for human experts.
589
+ </div>
590
+ <div id="stats">
591
+ <b>Trained on 4 classes:</b> Normal β€’ Tuberculosis β€’ Pneumonia β€’ COVID-19<br/>
592
+ <b>Energy-efficient:</b> Adaptive Sparse Training (AST) – ~89% compute savings (research setting)<br/>
593
+ <b>Use case:</b> Triage & screening support for TB, pneumonia, and COVID-19 programs
594
+ </div>
595
+ </div>
596
+ """
597
+ )
598
+
599
+ with gr.Row():
600
+ with gr.Column(scale=1, elem_id="upload-box"):
601
+ gr.Markdown("## πŸ“€ Upload Chest X-Ray")
602
+ image_input = gr.Image(
603
+ type="pil",
604
+ label="Upload X-Ray Image (PA or AP view)",
605
+ elem_classes="output-image",
606
+ )
607
+
608
+ show_gradcam = gr.Checkbox(
609
+ value=True,
610
+ label="Enable Grad-CAM (Explainable AI)",
611
+ info="Shows which lung regions the model is focusing on.",
612
+ )
613
+
614
+ analyze_btn = gr.Button("πŸ”¬ Analyze X-Ray", variant="primary", size="lg")
615
+
616
+ gr.Markdown(
617
+ """
618
+ ### πŸ“‹ Supported Images
619
+ - Chest X-rays (PA or AP view)
620
+ - PNG / JPG / JPEG
621
+ - Grayscale or RGB
622
+
623
+ ### πŸ’‘ Designed For
624
+ - TB & pneumonia screening programs
625
+ - Remote / low-resource clinics
626
+ - Educational and research use
627
+
628
+ > ⚠️ Always combine AI output with clinical judgment and lab tests.
629
+ """
630
+ )
631
+
632
+ with gr.Column(scale=2, elem_id="results-box"):
633
+ gr.Markdown("## πŸ“Š AI Analysis Results")
634
+
635
+ with gr.Row():
636
+ prob_output = gr.Label(
637
+ label="Prediction Confidence (per class)",
638
+ num_top_classes=4,
639
+ )
640
+
641
+ with gr.Tabs():
642
+ with gr.Tab("Original (Annotated)"):
643
+ original_output = gr.Image(
644
+ label="Annotated X-Ray",
645
+ elem_classes="output-image",
646
+ )
647
+
648
+ with gr.Tab("Grad-CAM Heatmap"):
649
+ gradcam_output = gr.Image(
650
+ label="Model Attention Heatmap",
651
+ elem_classes="output-image",
652
+ )
653
+
654
+ with gr.Tab("Overlay"):
655
+ overlay_output = gr.Image(
656
+ label="Explainable AI Overlay",
657
+ elem_classes="output-image",
658
+ )
659
+
660
+ interpretation_output = gr.Markdown(label="Clinical-Style Interpretation")
661
+
662
+ gr.Markdown("## πŸ“ Example X-Rays (for testing only – not real patients)")
663
+ gr.Examples(
664
+ examples=[
665
+ ["examples/normal.png"],
666
+ ["examples/tb.png"],
667
+ ["examples/pneumonia.png"],
668
+ ["examples/covid.png"],
669
+ ],
670
+ inputs=image_input,
671
+ label="Click an example to load it into the app",
672
+ )
673
+
674
+ analyze_btn.click(
675
+ fn=predict_chest_xray,
676
+ inputs=[image_input, show_gradcam],
677
+ outputs=[
678
+ prob_output,
679
+ original_output,
680
+ gradcam_output,
681
+ overlay_output,
682
+ interpretation_output,
683
+ ],
684
+ )
685
+
686
+ gr.HTML(
687
+ """
688
+ <footer>
689
+ <p>
690
+ <b>🫁 Global Chest X-Ray Screening with Adaptive Sparse Training</b><br/>
691
+ Built as open research to support clinicians and public health teams worldwide.<br/>
692
+ Not a medical device β€’ Not for autonomous diagnosis or treatment decisions.
693
+ </p>
694
+ <p style="font-size: 0.8em; margin-top: 12px;">
695
+ ⚠️ <b>MEDICAL DISCLAIMER:</b> This tool is for research and educational use only.
696
+ All findings must be confirmed by qualified medical professionals using
697
+ appropriate clinical and laboratory standards.
698
+ </p>
699
+ </footer>
700
+ """
701
+ )
702
+
703
+ if __name__ == "__main__":
704
+ demo.launch(
705
+ share=False,
706
+ server_name="0.0.0.0",
707
+ server_port=7860,
708
+ show_error=True,
709
+ )