""" | |
Chest X-ray Classification App with Multiple Models | |
Integrates multiple models with unified Grad-CAM visualization | |
""" | |
import os | |
import torch | |
import numpy as np | |
import gradio as gr | |
import torchxrayvision as xrv | |
from PIL import Image | |
import torchvision.transforms as transforms | |
from torchvision import models | |
import torch.nn as nn | |
import torch.serialization
import matplotlib
matplotlib.use('Agg')  # Select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import cv2
# Import the custom EfficientNet-B3 model from separate module
# This is actually the DannyNet model but renamed for consistency
import efficientnet_b3_custom
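# PyTorch 2.6 switched torch.load to weights_only=True by default; allow-listing
# the numpy scalar type lets checkpoints that embed numpy values load under it.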
torch.serialization.add_safe_globals([np.core.multiarray.scalar])
# Define the available models
MODELS = {
    "DenseNet121": "densenet121-res224-nih",
    "EfficientNet-B3": "efficientnet_b3_custom",  # This is the DannyNet model
    "EfficientNet-B3 Original": "efficientnet_b3",
    "EfficientNet-B0": "efficientnet_b0"
}
# NIH ChestX-ray14 pathologies
NIH_PATHOLOGIES = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
    'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation',
    'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
]

# Cache for loaded models
loaded_models = {}

# For Grad-CAM
activation = {}
gradient = {}

# Hook functions for Grad-CAM
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output  # Keep gradient tracking
    return hook

def get_gradient(name):
    def hook(grad):
        gradient[name] = grad
    return hook
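# A minimal sketch of how these hooks pair up (illustrative only; the layer and
# class index below are hypothetical). The forward hook captures a layer's
# activations, and Tensor.register_hook on that activation captures its
# gradient during the backward pass:
#
#   handle = some_layer.register_forward_hook(get_activation('feat'))
#   out = model(x)
#   activation['feat'].register_hook(get_gradient('feat'))
#   out[0, class_idx].backward()
#   handle.remove()
#   # activation['feat'] and gradient['feat'] now hold the Grad-CAM inputs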
def load_model(model_name):
    if model_name in loaded_models:
        return loaded_models[model_name]

    model_type = MODELS[model_name]
    try:
        if model_name == "DenseNet121":
            # Load DenseNet121 from TorchXRayVision
            model = xrv.models.DenseNet(weights="densenet121-res224-nih")
            model.eval()
            loaded_models[model_name] = model
            return model
        elif model_type == "efficientnet_b3_custom":
            # Load the custom EfficientNet-B3 model (actually DannyNet)
            model_path = "dannynet-55-best_model_20250422-211522.pth"
            if os.path.exists(model_path):
                model = efficientnet_b3_custom.load_model(model_path, device='cpu')
                print(f"Successfully loaded EfficientNet-B3 from {model_path}")
                loaded_models[model_name] = model
                return model
            else:
                print(f"Model file not found: {model_path}")
                print("Please place the model file in the same directory as this script.")
                return None
elif model_type == "efficientnet_b3": | |
# Import EfficientNet dynamically to avoid dependency issues | |
try: | |
from efficientnet_pytorch import EfficientNet | |
model = EfficientNet.from_name('efficientnet-b3', num_classes=14) | |
except ImportError: | |
# Fallback to torchvision if efficientnet_pytorch is not available | |
model = models.efficientnet_b3(pretrained=False) | |
num_ftrs = model.classifier[1].in_features | |
model.classifier[1] = nn.Linear(num_ftrs, 14) | |
# Load your trained weights | |
model_path = os.path.join("weights", "best_model_b3.pt") | |
if os.path.exists(model_path): | |
# Explicitly set weights_only=False for PyTorch 2.6+ compatibility | |
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) | |
if "model_state_dict" in checkpoint: | |
model.load_state_dict(checkpoint["model_state_dict"], strict=False) | |
else: | |
model.load_state_dict(checkpoint, strict=False) | |
print(f"Successfully loaded EfficientNet-B3 Original from {model_path}") | |
else: | |
print(f"Model file not found: {model_path}") | |
print("Using a new model instance. Please place your trained model in the weights directory.") | |
elif model_type == "efficientnet_b0": | |
# Import EfficientNet dynamically to avoid dependency issues | |
try: | |
from efficientnet_pytorch import EfficientNet | |
model = EfficientNet.from_name('efficientnet-b0', num_classes=14) | |
except ImportError: | |
# Fallback to torchvision if efficientnet_pytorch is not available | |
model = models.efficientnet_b0(pretrained=False) | |
num_ftrs = model.classifier[1].in_features | |
model.classifier[1] = nn.Linear(num_ftrs, 14) | |
# Load your trained weights | |
model_path = os.path.join("weights", "best_model_b0.pt") | |
if os.path.exists(model_path): | |
# Explicitly set weights_only=False for PyTorch 2.6+ compatibility | |
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) | |
if "model_state_dict" in checkpoint: | |
model.load_state_dict(checkpoint["model_state_dict"], strict=False) | |
else: | |
model.load_state_dict(checkpoint, strict=False) | |
print(f"Successfully loaded EfficientNet-B0 from {model_path}") | |
else: | |
print(f"Model file not found: {model_path}") | |
print("Using a new model instance. Please place your trained model in the weights directory.") | |
        model.eval()
        loaded_models[model_name] = model
        return model
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return None
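# Hedged usage sketch: models are cached after the first call, so repeated
# lookups are cheap.
#
#   model = load_model("DenseNet121")  # downloads xrv weights on first use
#   model = load_model("DenseNet121")  # served from the loaded_models cache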
def preprocess_image_densenet(img):
    """Preprocess an image for the DenseNet model."""
    # Convert to grayscale if it's a color image
    if len(img.shape) > 2:
        img = img.mean(2)
    # Normalize the image
    img = xrv.datasets.normalize(img, 255)
    # Add channel dimension
    if len(img.shape) == 2:
        img = img[None, ...]
    return img
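# Illustrative sketch (the file path is hypothetical): the DenseNet pipeline
# expects a single-channel array in xrv's value range.
#
#   img = np.array(Image.open("xray.png").convert('RGB'))
#   img = preprocess_image_densenet(img)  # shape (1, H, W), xrv-normalized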
def preprocess_image_efficientnet(img, img_size=224):
    """Preprocess an image for the EfficientNet models."""
    # Convert to PIL Image if it's a numpy array
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # Ensure image is in RGB mode
    img = img.convert('RGB')
    # Define preprocessing transforms
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Apply transforms
    img_tensor = transform(img)
    return img_tensor
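# Illustrative usage: the returned tensor is (3, img_size, img_size) and
# normalized with ImageNet statistics, so a single-image batch is built with
#
#   tensor = preprocess_image_efficientnet(np_image)
#   logits = model(tensor.unsqueeze(0))  # add the batch dimension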
# Unified Grad-CAM implementation for all models
def compute_unified_gradcam(model, model_name, img_tensor, target_class_idx):
    """
    Compute Grad-CAM using a unified approach for all models
    This ensures consistent visualization style across all models
    """
    # For the custom EfficientNet-B3 model (DannyNet), use its dedicated function
    if MODELS[model_name] == "efficientnet_b3_custom":
        return efficientnet_b3_custom.compute_gradcam(model, img_tensor, target_class_idx)

    # For other models, implement a similar approach to ensure consistent results
    # Register hooks
    activation.clear()
    gradient.clear()

    # Select appropriate target layer based on model type
    if model_name == "DenseNet121":
        target_layer = model.features.denseblock3
    elif "EfficientNet" in model_name:
        if hasattr(model, '_blocks'):
            middle_idx = len(model._blocks) // 2
            target_layer = model._blocks[middle_idx]
        else:
            middle_idx = len(model.features) // 2
            target_layer = model.features[middle_idx]
    else:
        # Default to last feature layer
        target_layer = model.features[-1]

    # Register forward hook
    handle = target_layer.register_forward_hook(get_activation('target_layer'))

    # Ensure input tensor requires gradients
    img_tensor_for_gradcam = img_tensor.clone().requires_grad_(True)

    # Forward pass
    model.zero_grad()
    # Handle different model output formats: the EfficientNet models emit
    # logits, which sigmoid maps to per-pathology probabilities
    output = model(img_tensor_for_gradcam.unsqueeze(0))
    if model_name != "DenseNet121":
        output = torch.sigmoid(output)
    # Target for backprop
    if target_class_idx is not None:
        score = output[0, target_class_idx]
    else:
        score, _ = output.max(dim=1)
        score = score[0]

    # Backward pass; retain_graph keeps the graph alive for the
    # torch.autograd.grad call below
    score.backward(retain_graph=True)

    # Clean up hook
    handle.remove()

    # Get activations
    if 'target_layer' not in activation:
        print("No activation captured")
        return None
    activations = activation['target_layer']

    # Try different approaches to get gradients
    try:
        # Get gradients using autograd
        gradients = torch.autograd.grad(score, activations,
                                        create_graph=True, retain_graph=True)[0]
    except Exception as e:
        print(f"Gradient calculation failed: {e}")
        # Create dummy gradients as fallback
        gradients = torch.ones_like(activations)
    # Use global average pooling of the absolute gradients for channel weights;
    # detach both tensors so the in-place weighting below stays outside autograd
    pooled_gradients = torch.mean(torch.abs(gradients), dim=[0, 2, 3]).detach()
    activations = activations.detach()

    # Weight activation maps with gradients
    for i in range(activations.size(1)):
        activations[:, i, :, :] *= pooled_gradients[i]

    # Sum along channels for final heatmap
    heatmap = torch.sum(activations, dim=1).squeeze().cpu().numpy()
    # ReLU on the heatmap
    heatmap = np.maximum(heatmap, 0)

    # Apply gamma correction to enhance contrast
    gamma = 0.7  # Values less than 1 enhance bright regions
    heatmap = np.power(heatmap, gamma)

    # Normalize heatmap
    if np.max(heatmap) > 0:
        heatmap = heatmap / np.max(heatmap)

    # Apply threshold to remove noise
    threshold = 0.2  # Only keep values above 20% of max
    heatmap[heatmap < threshold] = 0

    # Re-normalize after thresholding
    if np.max(heatmap) > 0:
        heatmap = heatmap / np.max(heatmap)

    # Resize to 224x224
    heatmap = cv2.resize(heatmap, (224, 224))
    return heatmap
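# Hedged usage sketch: given a preprocessed tensor and a pathology index, the
# heatmap comes back as a 224x224 float array in [0, 1], or None on failure.
#
#   target = NIH_PATHOLOGIES.index('Cardiomegaly')
#   heatmap = compute_unified_gradcam(model, "EfficientNet-B0", img_tensor, target)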
# Unified Grad-CAM overlay function for all models
def apply_unified_gradcam(original_img, heatmap, alpha=0.6):
    """
    Apply Grad-CAM heatmap to the original image using a unified approach
    This ensures consistent visualization style across all models
    """
    # Convert to numpy if it's a PIL Image
    if isinstance(original_img, Image.Image):
        original_img = np.array(original_img)

    # Resize original image to 224x224
    original_img = cv2.resize(original_img, (224, 224))

    # Convert original image to RGB if it's grayscale
    if len(original_img.shape) == 2:
        original_img = np.stack([original_img] * 3, axis=2)
    elif len(original_img.shape) == 3 and original_img.shape[2] == 1:
        original_img = np.concatenate([original_img] * 3, axis=2)

    # Convert heatmap to uint8 before applying median blur
    heatmap_uint8 = np.uint8(heatmap * 255)
    heatmap_blurred = cv2.medianBlur(heatmap_uint8, 7)

    # Convert back to float in range [0,1]
    heatmap = heatmap_blurred.astype(float) / 255.0

    # Apply colormap to heatmap - COLORMAP_HOT reads well on medical images
    heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_HOT)

    # OpenCV colormaps are BGR; convert to RGB to match the base image
    if len(original_img.shape) == 3 and original_img.shape[2] == 3:
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # Create a float copy of the original image for blending
    original_img_float = original_img.astype(float)

    # Superimpose heatmap on original image
    superimposed_img = heatmap_colored * alpha + original_img_float * (1 - alpha * 0.5)
    superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)

    # Add contour lines for the most significant regions
    binary_heatmap = (heatmap > 0.5).astype(np.uint8) * 255
    contours, _ = cv2.findContours(binary_heatmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(superimposed_img, contours, -1, (255, 255, 255), 1)

    return superimposed_img
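# Hedged sketch of the overlay step (the output path is hypothetical):
#
#   overlay = apply_unified_gradcam(uploaded_img, heatmap, alpha=0.6)
#   Image.fromarray(overlay).save("gradcam_overlay.png")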
def predict_with_gradcam(image, model_name, confidence_threshold=0.5):
    """Make predictions and generate Grad-CAM visualization."""
    if image is None:
        return None, "Please upload an image."
    try:
        # Create weights directory if it doesn't exist (for EfficientNet models)
        if "EfficientNet" in model_name:
            os.makedirs("weights", exist_ok=True)

        # Load the model
        model = load_model(model_name)
        if model is None:
            return None, f"Failed to load model {model_name}. Please check the console for details."

        # Save original image for visualization
        original_img = np.array(image).copy()

        # Process based on model type
        if model_name == "DenseNet121":
            # Read and preprocess the image for DenseNet from TorchXRayVision
            img = np.array(Image.fromarray(image).convert('RGB'))
            img_processed = preprocess_image_densenet(img)

            # Create transforms
            transform = transforms.Compose([
                xrv.datasets.XRayCenterCrop(),
                xrv.datasets.XRayResizer(224)
            ])

            # Apply transforms
            img_processed = transform(img_processed)

            # Convert to tensor
            img_tensor = torch.from_numpy(img_processed)

            # Make prediction
            with torch.no_grad():
                output = model(img_tensor.unsqueeze(0))
                probabilities = output.squeeze().numpy()
            # Map outputs using the model's own pathology ordering (exposed by
            # torchxrayvision as model.pathologies), keeping only the NIH-14 labels;
            # zipping against NIH_PATHOLOGIES directly would misalign the scores
            results = {pathology: float(prob) for pathology, prob in zip(model.pathologies, probabilities) if pathology in NIH_PATHOLOGIES}
        elif MODELS[model_name] == "efficientnet_b3_custom":
            # Use the custom EfficientNet-B3 module for preprocessing and prediction
            img_tensor = efficientnet_b3_custom.preprocess_image(image)

            # Make prediction
            results = efficientnet_b3_custom.predict(model, img_tensor)
        else:  # Other EfficientNet models
            # Preprocess the image for EfficientNet
            img_tensor = preprocess_image_efficientnet(image)

            # Make prediction
            with torch.no_grad():
                output = model(img_tensor.unsqueeze(0))
                probabilities = torch.sigmoid(output).squeeze().numpy()

            # Create dictionary of results
            results = {pathology: float(prob) for pathology, prob in zip(NIH_PATHOLOGIES, probabilities)}

        # Sort results by probability (descending)
        sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))

        # Get top pathologies above threshold
        top_pathologies = [p for p, prob in sorted_results.items() if prob >= confidence_threshold]

        # Generate Grad-CAM for the top pathology
        gradcam_img = None
        if top_pathologies:
            # Get the output index of the top pathology; DenseNet121 follows
            # torchxrayvision's own pathology ordering
            top_pathology = top_pathologies[0]
            if model_name == "DenseNet121":
                target_idx = model.pathologies.index(top_pathology)
            else:
                target_idx = NIH_PATHOLOGIES.index(top_pathology)

            # Compute Grad-CAM using the unified approach
            heatmap = compute_unified_gradcam(model, model_name, img_tensor, target_idx)
            if heatmap is not None:
                # Apply Grad-CAM overlay using the unified approach
                gradcam_img = apply_unified_gradcam(original_img, heatmap)

        # If no Grad-CAM was generated, return the original image
        if gradcam_img is None:
            gradcam_img = original_img
        # Format results for display
        result_html = "<h3>Detected Conditions:</h3><ul>"
        detected_count = 0
        for pathology, prob in sorted_results.items():
            if prob >= confidence_threshold:
                detected_count += 1
                result_html += f"<li><b>{pathology}</b>: {prob:.4f} ({prob * 100:.1f}%)</li>"
        if detected_count == 0:
            result_html += "<li>No conditions detected above the confidence threshold.</li>"
        result_html += "</ul>"

        # Add a section for all probabilities; only the single top pathology
        # actually drives the Grad-CAM, so only it gets the marker
        result_html += "<h3>All Probabilities:</h3><ul>"
        for pathology, prob in sorted_results.items():
            if top_pathologies and pathology == top_pathologies[0]:
                result_html += f"<li>{pathology}: {prob:.4f} ({prob * 100:.1f}%) - <span style='color:red'>Used for Grad-CAM</span></li>"
            else:
                result_html += f"<li>{pathology}: {prob:.4f} ({prob * 100:.1f}%)</li>"
        result_html += "</ul>"
        # Add explanation about Grad-CAM
        result_html += "<h3>About Grad-CAM:</h3>"
        result_html += "<p>Grad-CAM highlights regions that influenced the model's prediction for the detected conditions. "
        result_html += "Red/yellow areas indicate regions of high importance for the diagnosis.</p>"

        # Add explanation about the improved visualization
        result_html += "<p><b>Improved Visualization:</b> This enhanced Grad-CAM applies gamma correction, noise thresholding, and contour outlines to better highlight clinically relevant regions. "
        result_html += "The visualization focuses on the most significant condition detected in the X-ray.</p>"
        return gradcam_img, result_html
    except Exception as e:
        import traceback
        traceback_str = traceback.format_exc()
        print(f"Error processing image: {str(e)}")
        print(traceback_str)
        return None, f"Error processing image: {str(e)}"
# Create the Gradio interface
with gr.Blocks(title="Chest X-ray Disease Classifier with Improved Grad-CAM") as demo:
    gr.Markdown("# Chest X-ray Disease Classifier with Improved Grad-CAM")
    gr.Markdown("Upload a chest X-ray image and select a model to detect potential conditions. The enhanced Grad-CAM visualization highlights the clinically relevant regions that influenced the diagnosis.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Chest X-ray Image", type="numpy")
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="EfficientNet-B3",  # Default to the DannyNet model
                label="Select Model"
            )
            confidence = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.05,
                label="Confidence Threshold"
            )
            submit_button = gr.Button("Analyze X-ray")
        with gr.Column(scale=2):
            # Only the Grad-CAM view is shown; the original image is not displayed
            gradcam_image = gr.Image(label="Grad-CAM Visualization")
            output_text = gr.HTML(label="Results")

    submit_button.click(
        fn=predict_with_gradcam,
        inputs=[input_image, model_dropdown, confidence],
        outputs=[gradcam_image, output_text]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
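    # When running locally, demo.launch(share=True) would also create a public
    # Gradio link; Hugging Face Spaces serves the app without extra arguments.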