Energy-efficient: Adaptive Sparse Training (AST) – ~89% compute savings (research setting)
Use case: Triage & screening support for TB, pneumonia, and COVID-19 programs
""" 🫁 Multi-Class Chest X-Ray Detection with Adaptive Sparse Training 4-Class Screening: Normal, Tuberculosis, Pneumonia, COVID-19 Mission: This open research tool is being built to help humanity – especially patients and clinicians in low-resource settings – by providing energy-efficient, explainable AI support for chest X-ray screening. It is a digital second opinion, NOT a replacement for radiologists or doctors. """ import gradio as gr import torch import torch.nn as nn from torchvision import models, transforms from PIL import Image import numpy as np import cv2 import matplotlib.pyplot as plt from pathlib import Path import io # ============================================================================ # Model Setup # ============================================================================ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def load_efficientnet_model(): """ Build EfficientNet-B2 and load your working 4-class best.pt checkpoint. We intentionally keep this simple and very close to the version you already confirmed is working, to avoid shape-mismatch issues. """ # Base architecture: EfficientNet-B2 model = models.efficientnet_b2(weights=None) in_features = model.classifier[1].in_features model.classifier[1] = nn.Linear(in_features, 4) # 4 classes # Where we expect your weights to live candidate_paths = [ Path("checkpoints/best.pt"), # HF Space path (from your screenshot) Path("best.pt"), # fallback for local runs ] last_error = None for ckpt_path in candidate_paths: if not ckpt_path.exists(): print(f"⚠️ Checkpoint not found at {ckpt_path}") continue try: print(f"🔍 Loading weights from: {ckpt_path}") state = torch.load(ckpt_path, map_location=device) # If it comes from a training script with wrappers if isinstance(state, dict): if "model_state_dict" in state: state = state["model_state_dict"] elif "state_dict" in state: state = state["state_dict"] # This is the same idea as your original working call missing, unexpected = model.load_state_dict(state, strict=False) if missing or unexpected: print(f" ⚠️ Non-critical keys - missing: {missing}, unexpected: {unexpected}") print(f"✅ Model weights successfully loaded from {ckpt_path}") model.to(device) model.eval() return model except Exception as e: print(f"❌ Could not load from {ckpt_path}: {e}") last_error = e raise RuntimeError( "Could not load EfficientNet-B2 4-class weights from any known path.\n" f"Last error: {last_error}" ) model = load_efficientnet_model() # Classes CLASSES = ["Normal", "Tuberculosis", "Pneumonia", "COVID-19"] CLASS_COLORS = { "Normal": "#2ecc71", # Green "Tuberculosis": "#e74c3c", # Red "Pneumonia": "#f39c12", # Orange "COVID-19": "#9b59b6", # Purple } # Image preprocessing transform = transforms.Compose( [ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], ), ] ) # ============================================================================ # Grad-CAM Implementation # ============================================================================ class GradCAM: def __init__(self, model, target_layer): self.model = model self.target_layer = target_layer self.gradients = None self.activations = None def save_gradient(grad): self.gradients = grad def save_activation(module, input, output): self.activations = output.detach() # Forward hook: store activations target_layer.register_forward_hook(save_activation) # Backward hook: store gradients target_layer.register_full_backward_hook( lambda m, grad_in, grad_out: save_gradient(grad_out[0]) ) def generate(self, input_image, target_class=None): output = self.model(input_image) if target_class is None: target_class = output.argmax(dim=1) self.model.zero_grad() one_hot = torch.zeros_like(output) one_hot[0][target_class] = 1 output.backward(gradient=one_hot, retain_graph=True) if self.gradients is None or self.activations is None: return None, output # Global average pooling over gradients weights = self.gradients.mean(dim=(2, 3), keepdim=True) cam = (weights * self.activations).sum(dim=1, keepdim=True) cam = torch.relu(cam) cam = cam.squeeze().cpu().numpy() cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8) return cam, output # Setup Grad-CAM on the last feature layer target_layer = model.features[-1] grad_cam = GradCAM(model, target_layer) # ============================================================================ # Prediction & Visualization # ============================================================================ def predict_chest_xray(image, show_gradcam=True): """ Predict disease class from chest X-ray with Grad-CAM visualization. Returns: - class probabilities dict - annotated original image - Grad-CAM heatmap image - overlay image - markdown clinical interpretation """ if image is None: return None, None, None, None, "Please upload a chest X-ray." # Convert to PIL if needed if isinstance(image, np.ndarray): image = Image.fromarray(image).convert("RGB") else: image = image.convert("RGB") # Keep original for visualization original_img = image.copy() # Preprocess input_tensor = transform(image).unsqueeze(0).to(device) # Forward + optional Grad-CAM with torch.set_grad_enabled(show_gradcam): if show_gradcam: cam, output = grad_cam.generate(input_tensor) else: cam = None output = model(input_tensor) # Probabilities probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy() prob_sum = float(np.sum(probs)) if not (0.99 <= prob_sum <= 1.01): print(f"⚠️ Probability sum is {prob_sum:.4f}, expected ~1.0 – check model weights.") pred_class = int(output.argmax(dim=1).item()) pred_label = CLASSES[pred_class] confidence = float(probs[pred_class] * 100.0) # Ensure values between 0–100 results = { CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100.0))) for i in range(len(CLASSES)) } # Visualizations original_pil = create_original_display(original_img, pred_label, confidence) if cam is not None and show_gradcam: gradcam_viz = create_gradcam_visualization( original_img, cam, pred_label, confidence ) overlay_viz = create_overlay_visualization(original_img, cam) else: gradcam_viz = None overlay_viz = None # Interpretation text interpretation = create_interpretation(pred_label, confidence, results) return results, original_pil, gradcam_viz, overlay_viz, interpretation def create_original_display(image, pred_label, confidence): """Create annotated original image.""" fig, ax = plt.subplots(figsize=(8, 8)) ax.imshow(image) ax.axis("off") color = CLASS_COLORS[pred_label] title = f"Prediction: {pred_label}\nConfidence: {confidence:.1f}%" ax.set_title(title, fontsize=16, fontweight="bold", color=color, pad=20) plt.tight_layout() buf = io.BytesIO() plt.savefig( buf, format="png", dpi=150, bbox_inches="tight", facecolor="white", ) plt.close() buf.seek(0) return Image.open(buf) def create_gradcam_visualization(image, cam, pred_label, confidence): """Create Grad-CAM heatmap.""" img_array = np.array(image.resize((224, 224))) cam_resized = cv2.resize(cam, (224, 224)) heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET) heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) fig, ax = plt.subplots(figsize=(8, 8)) ax.imshow(heatmap) ax.axis("off") ax.set_title( "Attention Heatmap\n(Areas the model focuses on)", fontsize=14, fontweight="bold", pad=20, ) plt.tight_layout() buf = io.BytesIO() plt.savefig( buf, format="png", dpi=150, bbox_inches="tight", facecolor="white", ) plt.close() buf.seek(0) return Image.open(buf) def create_overlay_visualization(image, cam): """Overlay original image and Grad-CAM heatmap.""" img_array = np.array(image.resize((224, 224))) / 255.0 cam_resized = cv2.resize(cam, (224, 224)) heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET) heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0 overlay = img_array * 0.5 + heatmap * 0.5 overlay = np.clip(overlay, 0, 1) fig, ax = plt.subplots(figsize=(8, 8)) ax.imshow(overlay) ax.axis("off") ax.set_title( "Explainable AI Visualization\n(Original + Heatmap)", fontsize=14, fontweight="bold", pad=20, ) plt.tight_layout() buf = io.BytesIO() plt.savefig( buf, format="png", dpi=150, bbox_inches="tight", facecolor="white", ) plt.close() buf.seek(0) return Image.open(buf) def create_interpretation(pred_label, confidence, results): """ Clinical-style interpretation text with strong global-health framing and strict medical disclaimer. """ interpretation = f""" ## 🫁 AI Chest X-Ray Screening – Global Health Edition This tool is part of an open effort to **support clinicians and patients worldwide**, especially in places where radiologists are scarce. --- ## 🔬 Analysis Summary **Predicted class:** **{pred_label}** **Model confidence:** **{confidence:.1f}%** ### Probability Breakdown - 🟢 Normal: **{results['Normal']:.1f}%** - 🔴 Tuberculosis: **{results['Tuberculosis']:.1f}%** - 🟠 Pneumonia: **{results['Pneumonia']:.1f}%** - 🟣 COVID-19: **{results['COVID-19']:.1f}%** --- """ # Disease-specific details if pred_label == "Tuberculosis": if confidence >= 85: interpretation += """ ### ⚠️ High-Confidence Tuberculosis Pattern Detected The AI model has found features strongly suggestive of **pulmonary tuberculosis (TB)**. **Suggested next steps for a clinical team (NOT automatic orders):** 1. Correlate with symptoms: - Cough > 2 weeks - Night sweats, fever - Weight loss - Hemoptysis (coughing blood) 2. Order **confirmatory TB tests**: - Sputum smear / culture - GeneXpert MTB/RIF or TB-PCR 3. Consider **isolation** and **contact screening** if TB is suspected. 4. Evaluate HIV status and comorbidities according to local guidelines. ➡️ This system is designed to **support TB programs** in low-resource settings, where early triage can save lives. """ else: interpretation += """ ### ⚠️ Possible Tuberculosis Features The model sees **TB-like patterns**, but confidence is moderate. **Recommended clinical follow-up (not automatic diagnosis):** - Detailed history and physical examination - Evaluate TB risk factors and symptoms - Consider sputum-based TB testing - Repeat imaging or CT if clinically indicated """ elif pred_label == "Pneumonia": if confidence >= 85: interpretation += """ ### ⚠️ High-Confidence Pneumonia Pattern The model detects findings consistent with **pneumonia**. **Clinical team may consider:** - Distinguishing bacterial vs viral pneumonia - Correlating with: - Fever, cough, sputum - Pleuritic chest pain - Shortness of breath - Laboratory tests (WBC, CRP, cultures) - Guideline-based antibiotic or supportive therapy if confirmed This tool aims to **prioritize patients** for rapid review, especially where waiting times are long. """ else: interpretation += """ ### ⚠️ Possible Pneumonia The chest X-ray may show **subtle or early pneumonia-like changes**. **Clinical suggestions:** - Evaluate symptoms and vital signs - Consider repeat imaging or further labs - Use local pneumonia treatment guidelines if diagnosis is confirmed """ elif pred_label == "COVID-19": if confidence >= 85: interpretation += """ ### ⚠️ High-Confidence COVID-19 Pneumonia Pattern The AI sees a pattern often associated with **COVID-19 pneumonia**. **Clinical next steps typically include:** - **SARS-CoV-2 testing** (RT-PCR or antigen) - Isolation and infection prevention - Monitoring oxygen saturation (SpO2) - Urgent care if: - SpO2 < 94% - Respiratory distress - Persistent chest pain or confusion Imaging alone **cannot confirm COVID-19**. Lab testing + clinical judgment are essential. """ else: interpretation += """ ### ⚠️ Possible COVID-19 Pattern There are features that *could* be compatible with COVID-19, but the AI is not very certain. **Clinical suggestions:** - Follow local COVID-19 testing protocols - Use symptoms and exposure history - Monitor for deterioration and hypoxia """ else: # Normal if confidence >= 85: interpretation += """ ### ✅ High-Confidence "No Major Abnormality" Pattern The model does **not** see strong evidence of TB, pneumonia, or COVID-19. This may support a **normal chest X-ray**, but: - Early disease can be radiographically subtle - Some lung or cardiac diseases are **not detectable** here - Symptoms always override AI reassurance If a patient is symptomatic, clinical review is still required. """ else: interpretation += """ ### ⚠️ Likely Normal, But With Low Confidence The model leans toward a **normal** study, but uncertainty is higher than usual. - If the patient is unwell, treat this as **inconclusive** - Consider follow-up imaging or alternative diagnostics """ interpretation += """ --- ## 🌍 Built to Help Humanity This AI system is being developed to: - Support **front-line clinicians** in low-resource and high-burden regions - Provide an **energy-efficient (Adaptive Sparse Training)** screening assistant - Help triage patients when **radiologists are not immediately available** It is **open research**, not a commercial product, and we welcome feedback from clinicians, researchers, and public health teams. --- ## ⚠️ Critical Medical Disclaimer - This is a **screening and research tool only** – **NOT** an FDA/CE approved device. - It does **not** replace radiologists, pulmonologists, or infectious disease experts. - All decisions about diagnosis and treatment must be made by qualified clinicians. - Gold-standard confirmation remains: - **TB** – sputum tests, culture, GeneXpert, TB-PCR - **Pneumonia** – full clinical assessment + labs/imaging - **COVID-19** – RT-PCR / validated antigen testing If there is any doubt, always follow local clinical guidelines and consult a specialist. """ return interpretation # ============================================================================ # Gradio Interface # ============================================================================ custom_css = """ #main-container { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; } #title { text-align: center; color: white; font-size: 2.5em; font-weight: 800; margin-bottom: 10px; text-shadow: 2px 2px 4px rgba(0,0,0,0.35); } #subtitle { text-align: center; color: #f5f5ff; font-size: 1.1em; margin-bottom: 12px; } #mission { text-align: center; color: #ffffff; font-size: 0.95em; margin-bottom: 24px; padding: 14px 18px; background: rgba(0,0,0,0.15); border-radius: 12px; backdrop-filter: blur(12px); } #stats { text-align: center; color: #fff; font-size: 0.95em; margin-bottom: 30px; padding: 12px 16px; background: rgba(255,255,255,0.08); border-radius: 10px; } .gradio-container { font-family: "Inter", system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; } #upload-box { border: 3px dashed #667eea; border-radius: 15px; padding: 20px; background: rgba(255,255,255,0.97); } #results-box { background: white; border-radius: 15px; padding: 20px; box-shadow: 0 4px 12px rgba(0,0,0,0.12); } .output-image { border-radius: 10px; box-shadow: 0 2px 6px rgba(0,0,0,0.15); } footer { text-align: center; margin-top: 30px; color: white; font-size: 0.9em; } """ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo: gr.HTML( """