"""
Chest X-ray Classification App with Multiple Models
Integrates multiple models with unified Grad-CAM visualization
"""
import os
import torch
import numpy as np
import gradio as gr
import torchxrayvision as xrv
from PIL import Image
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
import torch.serialization
import matplotlib
matplotlib.use('Agg')  # Select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import cv2
# Import the custom EfficientNet-B3 model from separate module
# This is actually the DannyNet model but renamed for consistency
import efficientnet_b3_custom
# Allow-list the NumPy scalar global so checkpoints that pickle NumPy scalars
# can still be unpickled under PyTorch's safe (weights_only) deserialization
torch.serialization.add_safe_globals([np.core.multiarray.scalar])
# Available models: display name -> internal identifier / weights tag
MODELS = {
"DenseNet121": "densenet121-res224-nih",
"EfficientNet-B3": "efficientnet_b3_custom", # This is the DannyNet model
"EfficientNet-B3 O": "efficientnet_b3",
"EfficientNet-B0": "efficientnet_b0"
}
# NIH ChestX-ray14 pathologies
NIH_PATHOLOGIES = [
'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation',
'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
]
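# The custom EfficientNet heads are assumed to emit 14 logits in exactly this
# order; TorchXRayVision's DenseNet uses its own model.pathologies ordering,
# which is handled separately below.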
# Cache for loaded models
loaded_models = {}
# For Grad-CAM
activation = {}
gradient = {}
# Hook functions for Grad-CAM
def get_activation(name):
def hook(model, input, output):
activation[name] = output # Keep gradient tracking
return hook
def get_gradient(name):
def hook(grad):
gradient[name] = grad
return hook
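# Note: get_gradient is not registered anywhere in the Grad-CAM path below,
# which pulls gradients via torch.autograd.grad instead; it is kept for
# optional Tensor.register_hook-style experiments.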
def load_model(model_name):
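    """Load (and cache) the model registered under model_name; return None on failure."""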
if model_name in loaded_models:
return loaded_models[model_name]
model_type = MODELS[model_name]
try:
if model_name == "DenseNet121":
# Load DenseNet121 from TorchXRayVision
model = xrv.models.DenseNet(weights="densenet121-res224-nih")
model.eval()
loaded_models[model_name] = model
return model
elif model_type == "efficientnet_b3_custom":
# Load the custom EfficientNet-B3 model (actually DannyNet)
model_path = "dannynet-55-best_model_20250422-211522.pth"
if os.path.exists(model_path):
model = efficientnet_b3_custom.load_model(model_path, device='cpu')
print(f"Successfully loaded EfficientNet-B3 from {model_path}")
loaded_models[model_name] = model
return model
else:
print(f"Model file not found: {model_path}")
print("Please place the model file in the same directory as this script.")
return None
elif model_type == "efficientnet_b3":
# Import EfficientNet dynamically to avoid dependency issues
try:
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_name('efficientnet-b3', num_classes=14)
except ImportError:
# Fallback to torchvision if efficientnet_pytorch is not available
model = models.efficientnet_b3(pretrained=False)
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 14)
# Load your trained weights
model_path = os.path.join("weights", "best_model_b3.pt")
if os.path.exists(model_path):
# Explicitly set weights_only=False for PyTorch 2.6+ compatibility
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
if "model_state_dict" in checkpoint:
model.load_state_dict(checkpoint["model_state_dict"], strict=False)
else:
model.load_state_dict(checkpoint, strict=False)
print(f"Successfully loaded EfficientNet-B3 Original from {model_path}")
else:
print(f"Model file not found: {model_path}")
print("Using a new model instance. Please place your trained model in the weights directory.")
elif model_type == "efficientnet_b0":
# Import EfficientNet dynamically to avoid dependency issues
try:
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_name('efficientnet-b0', num_classes=14)
except ImportError:
# Fallback to torchvision if efficientnet_pytorch is not available
model = models.efficientnet_b0(pretrained=False)
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 14)
# Load your trained weights
model_path = os.path.join("weights", "best_model_b0.pt")
if os.path.exists(model_path):
# Explicitly set weights_only=False for PyTorch 2.6+ compatibility
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
if "model_state_dict" in checkpoint:
model.load_state_dict(checkpoint["model_state_dict"], strict=False)
else:
model.load_state_dict(checkpoint, strict=False)
print(f"Successfully loaded EfficientNet-B0 from {model_path}")
else:
print(f"Model file not found: {model_path}")
print("Using a new model instance. Please place your trained model in the weights directory.")
model.eval()
loaded_models[model_name] = model
return model
except Exception as e:
print(f"Error loading model {model_name}: {e}")
return None
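# Example (hypothetical interactive use) - models are cached after the first call:
#   model = load_model("DenseNet121")    # downloads TorchXRayVision weights on first use
#   model is load_model("DenseNet121")   # True: served from the loaded_models cache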
def preprocess_image_densenet(img):
"""Preprocess an image for the DenseNet model."""
# Convert to grayscale if it's a color image
if len(img.shape) > 2:
img = img.mean(2)
    # Map [0, 255] pixel values into the [-1024, 1024] range expected by TorchXRayVision
    img = xrv.datasets.normalize(img, 255)
# Add channel dimension
if len(img.shape) == 2:
img = img[None, ...]
return img
def preprocess_image_efficientnet(img, img_size=224):
"""Preprocess an image for the EfficientNet models."""
# Convert to PIL Image if it's a numpy array
if isinstance(img, np.ndarray):
img = Image.fromarray(img)
# Ensure image is in RGB mode
img = img.convert('RGB')
    # Define preprocessing transforms (ImageNet mean/std normalization; the
    # exact Resize already yields img_size x img_size, so no crop is needed)
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
# Apply transforms
img_tensor = transform(img)
return img_tensor
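# A minimal usage sketch (assumes a local "sample_xray.png" exists):
#   tensor = preprocess_image_efficientnet(Image.open("sample_xray.png"))
#   tensor.shape  # torch.Size([3, 224, 224]); add .unsqueeze(0) before inference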
# Unified Grad-CAM implementation for all models
def compute_unified_gradcam(model, model_name, img_tensor, target_class_idx):
"""
Compute Grad-CAM using a unified approach for all models
This ensures consistent visualization style across all models
"""
# For the custom EfficientNet-B3 model (DannyNet), use its dedicated function
if MODELS[model_name] == "efficientnet_b3_custom":
return efficientnet_b3_custom.compute_gradcam(model, img_tensor, target_class_idx)
# For other models, implement a similar approach to ensure consistent results
# Register hooks
activation.clear()
gradient.clear()
# Select appropriate target layer based on model type
if model_name == "DenseNet121":
target_layer = model.features.denseblock3
elif "EfficientNet" in model_name:
if hasattr(model, '_blocks'):
middle_idx = len(model._blocks) // 2
target_layer = model._blocks[middle_idx]
else:
middle_idx = len(model.features) // 2
target_layer = model.features[middle_idx]
else:
# Default to last feature layer
target_layer = model.features[-1]
# Register forward hook
handle = target_layer.register_forward_hook(get_activation('target_layer'))
# Ensure input tensor requires gradients
img_tensor_for_gradcam = img_tensor.clone().requires_grad_(True)
# Forward pass
model.zero_grad()
    # TorchXRayVision's DenseNet already maps its outputs into [0, 1]; the
    # EfficientNet heads emit raw logits, so apply a sigmoid for those models
    output = model(img_tensor_for_gradcam.unsqueeze(0))
    if model_name != "DenseNet121":
        output = torch.sigmoid(output)
# Target for backprop
if target_class_idx is not None:
score = output[0, target_class_idx]
else:
score, _ = output.max(dim=1)
score = score[0]
# Backward pass with retain_graph to avoid errors
score.backward(retain_graph=True)
# Clean up hook
handle.remove()
# Get activations
if 'target_layer' not in activation:
print("No activation captured")
return None
activations = activation['target_layer']
# Try different approaches to get gradients
try:
# Get gradients using autograd
gradients = torch.autograd.grad(score, activations,
create_graph=True, retain_graph=True)[0]
except Exception as e:
print(f"Gradient calculation failed: {e}")
# Create dummy gradients as fallback
gradients = torch.ones_like(activations)
    # Global-average-pool the absolute gradients per channel; taking the absolute
    # value keeps strong negative gradients from cancelling positive evidence
    pooled_gradients = torch.mean(torch.abs(gradients), dim=[0, 2, 3]).detach()
    # Weight the activation maps by the pooled gradients; detach and clone first
    # so the in-place multiply does not modify tensors still tracked by autograd
    weighted = activations.detach().clone()
    for i in range(weighted.size(1)):
        weighted[:, i, :, :] *= pooled_gradients[i]
    # Sum along the channel dimension for the final heatmap
    heatmap = torch.sum(weighted, dim=1).squeeze().cpu().numpy()
# ReLU on the heatmap
heatmap = np.maximum(heatmap, 0)
    # Gamma-correct the heatmap; an exponent below 1 lifts weaker activations
    # relative to the peak, improving contrast in mid-intensity regions
    gamma = 0.7
    heatmap = np.power(heatmap, gamma)
# Normalize heatmap
if np.max(heatmap) > 0:
heatmap = heatmap / np.max(heatmap)
# Apply threshold to remove noise
threshold = 0.2 # Only keep values above 20% of max
heatmap[heatmap < threshold] = 0
# Re-normalize after thresholding
if np.max(heatmap) > 0:
heatmap = heatmap / np.max(heatmap)
# Resize to 224x224
heatmap = cv2.resize(heatmap, (224, 224))
return heatmap
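# Grad-CAM recap: for a class score y_c and feature maps A^k, the canonical
# weights are alpha_k = GAP(dy_c / dA^k) and the map is ReLU(sum_k alpha_k A^k).
# The variant above departs from this by pooling |gradients| and adding gamma
# correction plus thresholding, trading fidelity for higher-contrast overlays.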
# Unified Grad-CAM overlay function for all models
def apply_unified_gradcam(original_img, heatmap, alpha=0.6):
"""
Apply Grad-CAM heatmap to the original image using a unified approach
This ensures consistent visualization style across all models
"""
# Convert to numpy if it's a PIL Image
if isinstance(original_img, Image.Image):
original_img = np.array(original_img)
# Resize original image to 224x224
original_img = cv2.resize(original_img, (224, 224))
# Convert original image to RGB if it's grayscale
if len(original_img.shape) == 2:
original_img = np.stack([original_img] * 3, axis=2)
elif len(original_img.shape) == 3 and original_img.shape[2] == 1:
original_img = np.concatenate([original_img] * 3, axis=2)
# Convert heatmap to uint8 before applying median blur
heatmap_uint8 = np.uint8(heatmap * 255)
heatmap_blurred = cv2.medianBlur(heatmap_uint8, 7)
# Convert back to float in range [0,1]
heatmap = heatmap_blurred.astype(float) / 255.0
# Apply colormap to heatmap - Use COLORMAP_HOT for better medical visualization
heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_HOT)
# Convert to RGB if needed
if len(original_img.shape) == 3 and original_img.shape[2] == 3:
heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
    # Create a float copy of the original image for blending
    original_img_float = original_img.astype(float)
    # Superimpose the heatmap; the blend weights intentionally sum to more than 1
    # (alpha + 1 - alpha/2) to brighten the overlay, and np.clip keeps it valid
    superimposed_img = heatmap_colored * alpha + original_img_float * (1 - alpha * 0.5)
    superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
# Add contour lines for the most significant regions
binary_heatmap = (heatmap > 0.5).astype(np.uint8) * 255
contours, _ = cv2.findContours(binary_heatmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(superimposed_img, contours, -1, (255, 255, 255), 1)
return superimposed_img
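# The returned overlay is an RGB uint8 array (224 x 224 x 3), ready to hand
# straight to gr.Image without further conversion.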
def predict_with_gradcam(image, model_name, confidence_threshold=0.5):
"""Make predictions and generate Grad-CAM visualization."""
if image is None:
return None, "Please upload an image."
try:
# Create weights directory if it doesn't exist (for EfficientNet models)
if "EfficientNet" in model_name:
os.makedirs("weights", exist_ok=True)
# Load the model
model = load_model(model_name)
if model is None:
return None, f"Failed to load model {model_name}. Please check the console for details."
# Save original image for visualization
original_img = np.array(image).copy()
# Process based on model type
if model_name == "DenseNet121":
# Read and preprocess the image for DenseNet from TorchXRayVision
img = np.array(Image.fromarray(image).convert('RGB'))
img_processed = preprocess_image_densenet(img)
# Create transforms
transform = transforms.Compose([
xrv.datasets.XRayCenterCrop(),
xrv.datasets.XRayResizer(224)
])
# Apply transforms
img_processed = transform(img_processed)
# Convert to tensor
img_tensor = torch.from_numpy(img_processed)
# Make prediction
with torch.no_grad():
output = model(img_tensor.unsqueeze(0))
probabilities = output.squeeze().numpy()
            # Map outputs using the model's own ordering: TorchXRayVision models
            # predict an 18-class output space, so zipping against the 14 NIH
            # labels directly would misalign them
            results = {pathology: float(prob)
                       for pathology, prob in zip(model.pathologies, probabilities)
                       if pathology in NIH_PATHOLOGIES}
elif MODELS[model_name] == "efficientnet_b3_custom":
# Use the custom EfficientNet-B3 module for preprocessing and prediction
img_tensor = efficientnet_b3_custom.preprocess_image(image)
# Make prediction
results = efficientnet_b3_custom.predict(model, img_tensor)
else: # Other EfficientNet models
# Preprocess the image for EfficientNet
img_tensor = preprocess_image_efficientnet(image)
# Make prediction
with torch.no_grad():
output = model(img_tensor.unsqueeze(0))
probabilities = torch.sigmoid(output).squeeze().numpy()
# Create dictionary of results
results = {pathology: float(prob) for pathology, prob in zip(NIH_PATHOLOGIES, probabilities)}
# Sort results by probability (descending)
sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
# Get top pathologies above threshold
top_pathologies = [p for p, prob in sorted_results.items() if prob >= confidence_threshold]
# Generate Grad-CAM for top pathologies
gradcam_img = None
        if top_pathologies:
            # Use the highest-probability pathology as the Grad-CAM target
            top_pathology = top_pathologies[0]
            if model_name == "DenseNet121":
                # TorchXRayVision models use their own 18-class output ordering
                target_idx = model.pathologies.index(top_pathology)
            else:
                target_idx = NIH_PATHOLOGIES.index(top_pathology)
            # Compute the Grad-CAM heatmap with the unified approach
            heatmap = compute_unified_gradcam(model, model_name, img_tensor, target_idx)
            if heatmap is not None:
                # Apply the Grad-CAM overlay with the unified style
                gradcam_img = apply_unified_gradcam(original_img, heatmap)
# If no Grad-CAM was generated, return the original image
if gradcam_img is None:
gradcam_img = original_img
# Format results for display
result_html = "<h3>Detected Conditions:</h3><ul>"
detected_count = 0
for pathology, prob in sorted_results.items():
if prob >= confidence_threshold:
detected_count += 1
result_html += f"<li><b>{pathology}</b>: {prob:.4f} ({prob * 100:.1f}%)</li>"
if detected_count == 0:
result_html += "<li>No conditions detected above the confidence threshold.</li>"
result_html += "</ul>"
# Add a section for all probabilities
result_html += "<h3>All Probabilities:</h3><ul>"
        for pathology, prob in sorted_results.items():
            if top_pathologies and pathology == top_pathologies[0]:
                result_html += f"<li>{pathology}: {prob:.4f} ({prob * 100:.1f}%) - <span style='color:red'>Used for Grad-CAM</span></li>"
            else:
                result_html += f"<li>{pathology}: {prob:.4f} ({prob * 100:.1f}%)</li>"
result_html += "</ul>"
# Add explanation about Grad-CAM
result_html += "<h3>About Grad-CAM:</h3>"
result_html += "<p>Grad-CAM highlights regions that influenced the model's prediction for the detected conditions. "
result_html += "Red/yellow areas indicate regions of high importance for the diagnosis.</p>"
        # Add explanation about the improved visualization
        result_html += "<p><b>Improved Visualization:</b> This enhanced Grad-CAM uses medical-imaging-oriented post-processing (gamma correction, thresholding, smoothing, and contour outlines) to better highlight clinically relevant regions. "
        result_html += "The visualization targets the most significant condition detected in the X-ray.</p>"
return gradcam_img, result_html
except Exception as e:
import traceback
traceback_str = traceback.format_exc()
print(f"Error processing image: {str(e)}")
print(traceback_str)
return None, f"Error processing image: {str(e)}"
# Create the Gradio interface
with gr.Blocks(title="Chest X-ray Disease Classifier with Improved Grad-CAM") as demo:
gr.Markdown("# Chest X-ray Disease Classifier with Improved Grad-CAM")
gr.Markdown("Upload a chest X-ray image and select a model to detect potential conditions. The enhanced Grad-CAM visualization will highlight clinically relevant regions influencing the diagnosis.")
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(label="Upload Chest X-ray Image", type="numpy")
model_dropdown = gr.Dropdown(
choices=list(MODELS.keys()),
value="EfficientNet-B3", # Default to the DannyNet model
label="Select Model"
)
confidence = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.5,
step=0.05,
label="Confidence Threshold"
)
submit_button = gr.Button("Analyze X-ray")
with gr.Column(scale=2):
gradcam_image = gr.Image(label="Grad-CAM Visualization")
output_text = gr.HTML(label="Results")
submit_button.click(
fn=predict_with_gradcam,
inputs=[input_image, model_dropdown, confidence],
outputs=[gradcam_image, output_text]
)
# Launch the app
if __name__ == "__main__":
demo.launch()