| """ | |
| π« Multi-Class Chest X-Ray Detection with Adaptive Sparse Training | |
| 4-Class Screening: Normal, Tuberculosis, Pneumonia, COVID-19 | |
| Mission: | |
| This open research tool is being built to help humanity β | |
| especially patients and clinicians in low-resource settings β | |
| by providing energy-efficient, explainable AI support for chest | |
| X-ray screening. It is a digital second opinion, NOT a replacement | |
| for radiologists or doctors. | |
| """ | |
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
import io

# ============================================================================
# Model Setup
# ============================================================================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_efficientnet_model():
    """
    Build EfficientNet-B2 and load the 4-class best.pt checkpoint.

    Kept deliberately simple and close to the version already confirmed
    to work, to avoid shape-mismatch issues when loading the weights.
    """
    # Base architecture: EfficientNet-B2 with a 4-class head
    model = models.efficientnet_b2(weights=None)
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, 4)  # 4 classes

    # Paths where the weights are expected to live
    candidate_paths = [
        Path("checkpoints/best.pt"),  # Hugging Face Space layout
        Path("best.pt"),              # fallback for local runs
    ]
    last_error = None
    for ckpt_path in candidate_paths:
        if not ckpt_path.exists():
            print(f"⚠️ Checkpoint not found at {ckpt_path}")
            continue
        try:
            print(f"📂 Loading weights from: {ckpt_path}")
            state = torch.load(ckpt_path, map_location=device)

            # Unwrap checkpoints saved by training scripts with wrapper dicts
            if isinstance(state, dict):
                if "model_state_dict" in state:
                    state = state["model_state_dict"]
                elif "state_dict" in state:
                    state = state["state_dict"]

            # Non-strict loading tolerates harmless key differences
            missing, unexpected = model.load_state_dict(state, strict=False)
            if missing or unexpected:
                print(f"  ⚠️ Non-critical keys - missing: {missing}, unexpected: {unexpected}")

            print(f"✅ Model weights successfully loaded from {ckpt_path}")
            model.to(device)
            model.eval()
            return model
        except Exception as e:
            print(f"❌ Could not load from {ckpt_path}: {e}")
            last_error = e

    raise RuntimeError(
        "Could not load EfficientNet-B2 4-class weights from any known path.\n"
        f"Last error: {last_error}"
    )
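
# NOTE (optional, hedged sketch - not part of the confirmed loading path):
# checkpoints saved from nn.DataParallel / DistributedDataParallel sometimes
# prefix every key with "module.". If load_state_dict() ever reports that all
# keys are missing/unexpected, stripping that prefix before loading usually
# resolves it:
#
#     state = {k.replace("module.", "", 1): v for k, v in state.items()}
#     missing, unexpected = model.load_state_dict(state, strict=False)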

model = load_efficientnet_model()

# Classes
CLASSES = ["Normal", "Tuberculosis", "Pneumonia", "COVID-19"]
CLASS_COLORS = {
    "Normal": "#2ecc71",        # Green
    "Tuberculosis": "#e74c3c",  # Red
    "Pneumonia": "#f39c12",     # Orange
    "COVID-19": "#9b59b6",      # Purple
}

# Image preprocessing
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)
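
# Example (hedged sketch): preprocessing a single X-ray outside the Gradio flow.
# "xray.png" is a hypothetical local file; .convert("RGB") maps grayscale scans
# to 3 channels so the tensor shape matches EfficientNet's input (1, 3, 224, 224).
#
#     img = Image.open("xray.png").convert("RGB")
#     x = transform(img).unsqueeze(0).to(device)
#     with torch.no_grad():
#         probs = torch.softmax(model(x), dim=1)[0]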

# ============================================================================
# Grad-CAM Implementation
# ============================================================================

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        def save_gradient(grad):
            self.gradients = grad

        def save_activation(module, input, output):
            self.activations = output.detach()

        # Forward hook: store activations
        target_layer.register_forward_hook(save_activation)
        # Backward hook: store gradients w.r.t. the target layer's output
        target_layer.register_full_backward_hook(
            lambda m, grad_in, grad_out: save_gradient(grad_out[0])
        )
    def generate(self, input_image, target_class=None):
        output = self.model(input_image)
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0][target_class] = 1
        output.backward(gradient=one_hot, retain_graph=True)

        if self.gradients is None or self.activations is None:
            return None, output

        # Global average pooling over gradients -> per-channel weights
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)
        cam = cam.squeeze().cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam, output

# Setup Grad-CAM on the last feature layer
target_layer = model.features[-1]
grad_cam = GradCAM(model, target_layer)
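
# Optional smoke test (hedged sketch): confirm the hooks fire and the CAM has a
# sensible spatial shape before wiring up the UI. Uses a random tensor, not a
# real X-ray, so the prediction itself is meaningless here.
#
#     _cam, _out = grad_cam.generate(torch.randn(1, 3, 224, 224, device=device))
#     print(_cam.shape, _out.shape)   # expected: roughly (7, 7) and torch.Size([1, 4])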

# ============================================================================
# Prediction & Visualization
# ============================================================================

def predict_chest_xray(image, show_gradcam=True):
    """
    Predict the disease class for a chest X-ray, with optional Grad-CAM.

    Returns:
        - class probabilities dict
        - annotated original image
        - Grad-CAM heatmap image
        - overlay image
        - markdown clinical interpretation
    """
    if image is None:
        return None, None, None, None, "Please upload a chest X-ray."

    # Convert to PIL if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert("RGB")
    else:
        image = image.convert("RGB")

    # Keep the original for visualization
    original_img = image.copy()

    # Preprocess
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Forward pass; gradients are only needed when Grad-CAM is requested
    with torch.set_grad_enabled(show_gradcam):
        if show_gradcam:
            cam, output = grad_cam.generate(input_tensor)
        else:
            cam = None
            output = model(input_tensor)

    # Probabilities
    probs = torch.softmax(output, dim=1)[0].detach().cpu().numpy()
    prob_sum = float(np.sum(probs))
    if not (0.99 <= prob_sum <= 1.01):
        print(f"⚠️ Probability sum is {prob_sum:.4f}, expected ~1.0 - check model weights.")

    pred_class = int(output.argmax(dim=1).item())
    pred_label = CLASSES[pred_class]
    confidence = float(probs[pred_class] * 100.0)

    # Clamp displayed values to the 0-100 range
    results = {
        CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100.0)))
        for i in range(len(CLASSES))
    }

    # Visualizations
    original_pil = create_original_display(original_img, pred_label, confidence)
    if cam is not None and show_gradcam:
        gradcam_viz = create_gradcam_visualization(original_img, cam, pred_label, confidence)
        overlay_viz = create_overlay_visualization(original_img, cam)
    else:
        gradcam_viz = None
        overlay_viz = None

    # Interpretation text
    interpretation = create_interpretation(pred_label, confidence, results)

    return results, original_pil, gradcam_viz, overlay_viz, interpretation
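
# Optional local check (hedged sketch): run the full pipeline on one image
# without launching Gradio. "sample_xray.png" is a hypothetical local file.
#
#     results, annotated, heatmap, overlay, report = predict_chest_xray(
#         Image.open("sample_xray.png")
#     )
#     print(results)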

def create_original_display(image, pred_label, confidence):
    """Create the annotated original image."""
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(image)
    ax.axis("off")

    color = CLASS_COLORS[pred_label]
    title = f"Prediction: {pred_label}\nConfidence: {confidence:.1f}%"
    ax.set_title(title, fontsize=16, fontweight="bold", color=color, pad=20)

    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(
        buf,
        format="png",
        dpi=150,
        bbox_inches="tight",
        facecolor="white",
    )
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

def create_gradcam_visualization(image, cam, pred_label, confidence):
    """Create the Grad-CAM heatmap image."""
    cam_resized = cv2.resize(cam, (224, 224))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(heatmap)
    ax.axis("off")
    ax.set_title(
        "Attention Heatmap\n(Areas the model focuses on)",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )

    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(
        buf,
        format="png",
        dpi=150,
        bbox_inches="tight",
        facecolor="white",
    )
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

def create_overlay_visualization(image, cam):
    """Overlay the original image with the Grad-CAM heatmap."""
    img_array = np.array(image.resize((224, 224))) / 255.0
    cam_resized = cv2.resize(cam, (224, 224))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0

    overlay = img_array * 0.5 + heatmap * 0.5
    overlay = np.clip(overlay, 0, 1)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(overlay)
    ax.axis("off")
    ax.set_title(
        "Explainable AI Visualization\n(Original + Heatmap)",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )

    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(
        buf,
        format="png",
        dpi=150,
        bbox_inches="tight",
        facecolor="white",
    )
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

def create_interpretation(pred_label, confidence, results):
    """
    Clinical-style interpretation text with global-health framing
    and a strict medical disclaimer.
    """
    interpretation = f"""
## 🫁 AI Chest X-Ray Screening - Global Health Edition

This tool is part of an open effort to **support clinicians and patients worldwide**,
especially in places where radiologists are scarce.

---

## 🔬 Analysis Summary

**Predicted class:** **{pred_label}**
**Model confidence:** **{confidence:.1f}%**

### Probability Breakdown
- 🟢 Normal: **{results['Normal']:.1f}%**
- 🔴 Tuberculosis: **{results['Tuberculosis']:.1f}%**
- 🟠 Pneumonia: **{results['Pneumonia']:.1f}%**
- 🟣 COVID-19: **{results['COVID-19']:.1f}%**

---
"""
    # Disease-specific details
    if pred_label == "Tuberculosis":
        if confidence >= 85:
            interpretation += """
### ⚠️ High-Confidence Tuberculosis Pattern Detected

The AI model has found features strongly suggestive of **pulmonary tuberculosis (TB)**.

**Suggested next steps for a clinical team (NOT automatic orders):**
1. Correlate with symptoms:
   - Cough > 2 weeks
   - Night sweats, fever
   - Weight loss
   - Hemoptysis (coughing blood)
2. Order **confirmatory TB tests**:
   - Sputum smear / culture
   - GeneXpert MTB/RIF or TB-PCR
3. Consider **isolation** and **contact screening** if TB is suspected.
4. Evaluate HIV status and comorbidities according to local guidelines.

➡️ This system is designed to **support TB programs** in low-resource settings,
where early triage can save lives.
"""
        else:
            interpretation += """
### ⚠️ Possible Tuberculosis Features

The model sees **TB-like patterns**, but confidence is moderate.

**Recommended clinical follow-up (not an automatic diagnosis):**
- Detailed history and physical examination
- Evaluate TB risk factors and symptoms
- Consider sputum-based TB testing
- Repeat imaging or CT if clinically indicated
"""
    elif pred_label == "Pneumonia":
        if confidence >= 85:
            interpretation += """
### ⚠️ High-Confidence Pneumonia Pattern

The model detects findings consistent with **pneumonia**.

**The clinical team may consider:**
- Distinguishing bacterial from viral pneumonia
- Correlating with:
  - Fever, cough, sputum
  - Pleuritic chest pain
  - Shortness of breath
- Laboratory tests (WBC, CRP, cultures)
- Guideline-based antibiotic or supportive therapy if confirmed

This tool aims to **prioritize patients** for rapid review, especially
where waiting times are long.
"""
        else:
            interpretation += """
### ⚠️ Possible Pneumonia

The chest X-ray may show **subtle or early pneumonia-like changes**.

**Clinical suggestions:**
- Evaluate symptoms and vital signs
- Consider repeat imaging or further labs
- Use local pneumonia treatment guidelines if the diagnosis is confirmed
"""
    elif pred_label == "COVID-19":
        if confidence >= 85:
            interpretation += """
### ⚠️ High-Confidence COVID-19 Pneumonia Pattern

The AI sees a pattern often associated with **COVID-19 pneumonia**.

**Clinical next steps typically include:**
- **SARS-CoV-2 testing** (RT-PCR or antigen)
- Isolation and infection prevention
- Monitoring oxygen saturation (SpO2)
- Urgent care if:
  - SpO2 < 94%
  - Respiratory distress
  - Persistent chest pain or confusion

Imaging alone **cannot confirm COVID-19**. Lab testing and clinical judgment are essential.
"""
        else:
            interpretation += """
### ⚠️ Possible COVID-19 Pattern

There are features that *could* be compatible with COVID-19, but the AI is not very certain.

**Clinical suggestions:**
- Follow local COVID-19 testing protocols
- Use symptoms and exposure history
- Monitor for deterioration and hypoxia
"""
    else:  # Normal
        if confidence >= 85:
            interpretation += """
### ✅ High-Confidence "No Major Abnormality" Pattern

The model does **not** see strong evidence of TB, pneumonia, or COVID-19.

This may support a **normal chest X-ray**, but:
- Early disease can be radiographically subtle
- Some lung or cardiac diseases are **not detectable** here
- Symptoms always override AI reassurance

If a patient is symptomatic, clinical review is still required.
"""
        else:
            interpretation += """
### ⚠️ Likely Normal, But With Low Confidence

The model leans toward a **normal** study, but uncertainty is higher than usual.

- If the patient is unwell, treat this as **inconclusive**
- Consider follow-up imaging or alternative diagnostics
"""

    interpretation += """
---

## 🌍 Built to Help Humanity

This AI system is being developed to:
- Support **front-line clinicians** in low-resource and high-burden regions
- Provide an **energy-efficient (Adaptive Sparse Training)** screening assistant
- Help triage patients when **radiologists are not immediately available**

It is **open research**, not a commercial product, and we welcome
feedback from clinicians, researchers, and public health teams.

---

## ⚠️ Critical Medical Disclaimer

- This is a **screening and research tool only**, **NOT** an FDA/CE approved device.
- It does **not** replace radiologists, pulmonologists, or infectious disease experts.
- All decisions about diagnosis and treatment must be made by qualified clinicians.
- Gold-standard confirmation remains:
  - **TB** → sputum tests, culture, GeneXpert, TB-PCR
  - **Pneumonia** → full clinical assessment + labs/imaging
  - **COVID-19** → RT-PCR / validated antigen testing

If there is any doubt, always follow local clinical guidelines and consult a specialist.
"""
    return interpretation

# ============================================================================
# Gradio Interface
# ============================================================================

custom_css = """
#main-container {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 20px;
}
#title {
    text-align: center;
    color: white;
    font-size: 2.5em;
    font-weight: 800;
    margin-bottom: 10px;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.35);
}
#subtitle {
    text-align: center;
    color: #f5f5ff;
    font-size: 1.1em;
    margin-bottom: 12px;
}
#mission {
    text-align: center;
    color: #ffffff;
    font-size: 0.95em;
    margin-bottom: 24px;
    padding: 14px 18px;
    background: rgba(0,0,0,0.15);
    border-radius: 12px;
    backdrop-filter: blur(12px);
}
#stats {
    text-align: center;
    color: #fff;
    font-size: 0.95em;
    margin-bottom: 30px;
    padding: 12px 16px;
    background: rgba(255,255,255,0.08);
    border-radius: 10px;
}
.gradio-container {
    font-family: "Inter", system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
}
#upload-box {
    border: 3px dashed #667eea;
    border-radius: 15px;
    padding: 20px;
    background: rgba(255,255,255,0.97);
}
#results-box {
    background: white;
    border-radius: 15px;
    padding: 20px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.12);
}
.output-image {
    border-radius: 10px;
    box-shadow: 0 2px 6px rgba(0,0,0,0.15);
}
footer {
    text-align: center;
    margin-top: 30px;
    color: white;
    font-size: 0.9em;
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML(
        """
        <div id="main-container">
            <div id="title">🫁 Global Chest X-Ray Screening AI</div>
            <div id="subtitle">
                4-Class detection • Explainable AI • Adaptive Sparse Training
            </div>
            <div id="mission">
                <b>Mission:</b> Support clinicians and patients worldwide, especially in
                low-resource, high-burden regions, by providing an energy-efficient AI
                assistant for chest X-ray screening. This is a <b>second opinion</b> tool,
                not a replacement for human experts.
            </div>
            <div id="stats">
                <b>Trained on 4 classes:</b> Normal • Tuberculosis • Pneumonia • COVID-19<br/>
                <b>Energy-efficient:</b> Adaptive Sparse Training (AST), ~89% compute savings (research setting)<br/>
                <b>Use case:</b> Triage & screening support for TB, pneumonia, and COVID-19 programs
            </div>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=1, elem_id="upload-box"):
            gr.Markdown("## 📤 Upload Chest X-Ray")
            image_input = gr.Image(
                type="pil",
                label="Upload X-Ray Image (PA or AP view)",
                elem_classes="output-image",
            )
            show_gradcam = gr.Checkbox(
                value=True,
                label="Enable Grad-CAM (Explainable AI)",
                info="Shows which lung regions the model is focusing on.",
            )
            analyze_btn = gr.Button("🔬 Analyze X-Ray", variant="primary", size="lg")
            gr.Markdown(
                """
                ### 📋 Supported Images
                - Chest X-rays (PA or AP view)
                - PNG / JPG / JPEG
                - Grayscale or RGB

                ### 💡 Designed For
                - TB & pneumonia screening programs
                - Remote / low-resource clinics
                - Educational and research use

                > ⚠️ Always combine AI output with clinical judgment and lab tests.
                """
            )

        with gr.Column(scale=2, elem_id="results-box"):
            gr.Markdown("## 📊 AI Analysis Results")
            with gr.Row():
                prob_output = gr.Label(
                    label="Prediction Confidence (per class)",
                    num_top_classes=4,
                )
            with gr.Tabs():
                with gr.Tab("Original (Annotated)"):
                    original_output = gr.Image(
                        label="Annotated X-Ray",
                        elem_classes="output-image",
                    )
                with gr.Tab("Grad-CAM Heatmap"):
                    gradcam_output = gr.Image(
                        label="Model Attention Heatmap",
                        elem_classes="output-image",
                    )
                with gr.Tab("Overlay"):
                    overlay_output = gr.Image(
                        label="Explainable AI Overlay",
                        elem_classes="output-image",
                    )
            interpretation_output = gr.Markdown(label="Clinical-Style Interpretation")
| gr.Markdown("## π Example X-Rays (for testing only β not real patients)") | |
| gr.Examples( | |
| examples=[ | |
| ["examples/normal.png"], | |
| ["examples/tb.png"], | |
| ["examples/pneumonia.png"], | |
| ["examples/covid.png"], | |
| ], | |
| inputs=image_input, | |
| label="Click an example to load it into the app", | |
| ) | |

    analyze_btn.click(
        fn=predict_chest_xray,
        inputs=[image_input, show_gradcam],
        outputs=[
            prob_output,
            original_output,
            gradcam_output,
            overlay_output,
            interpretation_output,
        ],
    )

    gr.HTML(
        """
        <footer>
            <p>
                <b>🫁 Global Chest X-Ray Screening with Adaptive Sparse Training</b><br/>
                Built as open research to support clinicians and public health teams worldwide.<br/>
                Not a medical device • Not for autonomous diagnosis or treatment decisions.
            </p>
            <p style="font-size: 0.8em; margin-top: 12px;">
                ⚠️ <b>MEDICAL DISCLAIMER:</b> This tool is for research and educational use only.
                All findings must be confirmed by qualified medical professionals using
                appropriate clinical and laboratory standards.
            </p>
        </footer>
        """
    )

if __name__ == "__main__":
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )
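
# Deployment note (assumption, inferred from the imports above): a Hugging Face
# Space running this app would need a requirements.txt roughly like the list
# below, with versions pinned to match the environment the checkpoint was
# trained in.
#
#     gradio
#     torch
#     torchvision
#     opencv-python-headless
#     matplotlib
#     Pillow
#     numpy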