"""
Chest X-ray Classification App with Multiple Models.

Integrates multiple models with unified Grad-CAM visualization and serves them
through a Gradio interface.
"""
import html
import os
import traceback

import cv2
import gradio as gr
import matplotlib
import numpy
import numpy as np
import torch
import torch.nn as nn
import torch.serialization
import torchvision.transforms as transforms
import torchxrayvision as xrv
from PIL import Image
from torchvision import models

matplotlib.use('Agg')  # Use non-interactive backend; selected before pyplot is imported.
import matplotlib.pyplot as plt  # noqa: E402,F401  (not used directly here)

# Import the custom EfficientNet-B3 model from separate module.
# This is actually the DannyNet model but renamed for consistency.
import efficientnet_b3_custom  # noqa: E402

torch.serialization.add_safe_globals([numpy.core.multiarray.scalar])

# Display name shown in the UI -> internal model identifier.
MODELS = {
    "DenseNet121": "densenet121-res224-nih",
    "EfficientNet-B3": "efficientnet_b3_custom",  # This is the DannyNet model
    "EfficientNet-B3 O": "efficientnet_b3",
    "EfficientNet-B0": "efficientnet_b0",
}

# NIH ChestX-ray14 pathologies, in the output order of the EfficientNet models.
NIH_PATHOLOGIES = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
    'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax',
    'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
    'Pleural_Thickening', 'Hernia',
]

# Cache for loaded models: display name -> model instance.
loaded_models = {}

# Scratch storage written by the Grad-CAM hooks below.
activation = {}
gradient = {}


def get_activation(name):
    """Return a forward hook that stores the module's output in ``activation[name]``."""
    def hook(module, inputs, output):
        activation[name] = output  # Keep gradient tracking so Grad-CAM can differentiate w.r.t. it.
    return hook


def get_gradient(name):
    """Return a tensor hook that stores the incoming gradient in ``gradient[name]``."""
    def hook(grad):
        gradient[name] = grad
    return hook


def _load_efficientnet(variant, weights_file, display_name):
    """Build a 14-class EfficientNet-<variant> and load ``weights/<weights_file>`` if present.

    Prefers the ``efficientnet_pytorch`` implementation and falls back to torchvision.
    If the checkpoint file is missing, the randomly initialised model is returned and a
    message is printed.
    """
    num_classes = len(NIH_PATHOLOGIES)
    try:
        # Imported lazily to avoid a hard dependency on efficientnet_pytorch.
        from efficientnet_pytorch import EfficientNet
        model = EfficientNet.from_name(f"efficientnet-{variant}", num_classes=num_classes)
    except ImportError:
        model = getattr(models, f"efficientnet_{variant}")(weights=None)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

    model_path = os.path.join("weights", weights_file)
    if not os.path.exists(model_path):
        print(f"Model file not found: {model_path}")
        print("Using a new model instance. Please place your trained model in the weights directory.")
        return model

    # weights_only=False keeps PyTorch 2.6+ able to read checkpoints that pickle extra
    # objects. It runs arbitrary pickle code, so only load trusted local files here.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    if isinstance(checkpoint, nn.Module):
        state_dict = checkpoint.state_dict()
    elif isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # strict=False tolerates small key differences, but report them: a checkpoint saved
    # from the other EfficientNet implementation would otherwise silently load nothing.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"Warning: {display_name} checkpoint mismatch - "
              f"{len(incompatible.missing_keys)} missing, "
              f"{len(incompatible.unexpected_keys)} unexpected keys")
    print(f"Successfully loaded {display_name} from {model_path}")
    return model


def load_model(model_name):
    """Load (and cache) the model registered under ``model_name`` in MODELS.

    Returns the model, or None if it could not be loaded (details are printed).
    Raises KeyError if ``model_name`` is not a key of MODELS.
    """
    if model_name in loaded_models:
        return loaded_models[model_name]

    model_type = MODELS[model_name]
    try:
        if model_name == "DenseNet121":
            # Load DenseNet121 from TorchXRayVision.
            model = xrv.models.DenseNet(weights="densenet121-res224-nih")
            model.eval()
        elif model_type == "efficientnet_b3_custom":
            # Custom EfficientNet-B3 (actually DannyNet); path is relative to the working directory.
            model_path = "dannynet-55-best_model_20250422-211522.pth"
            if not os.path.exists(model_path):
                print(f"Model file not found: {model_path}")
                print("Please place the model file in the same directory as this script.")
                return None
            model = efficientnet_b3_custom.load_model(model_path, device='cpu')
            print(f"Successfully loaded EfficientNet-B3 from {model_path}")
        elif model_type == "efficientnet_b3":
            model = _load_efficientnet("b3", "best_model_b3.pt", "EfficientNet-B3 Original")
            model.eval()
        elif model_type == "efficientnet_b0":
            model = _load_efficientnet("b0", "best_model_b0.pt", "EfficientNet-B0")
            model.eval()
        else:
            print(f"Unknown model type for {model_name}: {model_type}")
            return None
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return None

    loaded_models[model_name] = model
    return model


def preprocess_image_densenet(img):
    """Preprocess an image array for the TorchXRayVision DenseNet model.

    Colour images are averaged to grayscale, pixel values are rescaled with
    ``xrv.datasets.normalize(img, 255)`` and a leading channel axis is added,
    giving a (1, H, W) array.
    """
    if img.ndim > 2:
        img = img.mean(2)
    img = xrv.datasets.normalize(img, 255)
    if img.ndim == 2:
        img = img[None, ...]
    return img


def preprocess_image_efficientnet(img, img_size=224):
    """Preprocess an image (numpy array or PIL image) for the EfficientNet models.

    Returns a (3, img_size, img_size) float tensor normalised with ImageNet statistics.
    """
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    img = img.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(img)


def _select_gradcam_layer(model, model_name):
    """Return the module whose output feature map Grad-CAM is computed on."""
    if model_name == "DenseNet121":
        return model.features.denseblock3
    if "EfficientNet" in model_name:
        # efficientnet_pytorch exposes `_blocks`; torchvision exposes `features`.
        blocks = model._blocks if hasattr(model, '_blocks') else model.features
        return blocks[len(blocks) // 2]
    # Default to last feature layer.
    return model.features[-1]


def compute_unified_gradcam(model, model_name, img_tensor, target_class_idx):
    """
    Compute a Grad-CAM heatmap using a unified approach for all models.

    This ensures consistent visualization style across all models. The custom
    EfficientNet-B3 (DannyNet) model delegates to its own module's implementation.

    Args:
        model: A model returned by ``load_model``.
        model_name: Key of ``MODELS`` identifying ``model``.
        img_tensor: Preprocessed single image without a batch dimension.
        target_class_idx: Output index to explain, or None for the highest-scoring output.

    Returns:
        For the non-DannyNet models, a 224x224 float array with values in [0, 1], or
        None if no activation was captured.
    """
    if MODELS[model_name] == "efficientnet_b3_custom":
        return efficientnet_b3_custom.compute_gradcam(model, img_tensor, target_class_idx)

    activation.clear()
    gradient.clear()

    target_layer = _select_gradcam_layer(model, model_name)
    handle = target_layer.register_forward_hook(get_activation('target_layer'))
    try:
        # The caller may be inside torch.no_grad(); Grad-CAM needs an autograd graph.
        with torch.enable_grad():
            img_tensor_for_gradcam = img_tensor.clone().requires_grad_(True)
            output = model(img_tensor_for_gradcam.unsqueeze(0))
            if model_name != "DenseNet121":
                # DenseNet output is used as-is; the other models emit logits.
                output = torch.sigmoid(output)

            if target_class_idx is not None:
                score = output[0, target_class_idx]
            else:
                score = output[0].max()

            activations = activation.get('target_layer')
            if activations is None:
                print("No activation captured")
                return None

            # autograd.grad returns the gradient without accumulating .grad on the
            # (cached, shared) model's parameters.
            try:
                gradients = torch.autograd.grad(score, activations)[0]
            except RuntimeError as e:
                print(f"Gradient calculation failed: {e}")
                # Fallback: unweighted activation map.
                gradients = torch.ones_like(activations)
    finally:
        # Always detach the hook; the model is cached and reused across requests.
        handle.remove()

    acts = activations.detach()
    # Global average pooling of absolute gradients as channel weights (deliberately
    # highlights features with strong influence of either sign).
    weights = gradients.detach().abs().mean(dim=(0, 2, 3))
    # Weighted channel sum; [0] drops the batch dim without squeezing 1-pixel spatial dims.
    heatmap = (acts * weights.view(1, -1, 1, 1)).sum(dim=1)[0].cpu().numpy()

    # ReLU on the heatmap.
    heatmap = np.maximum(heatmap, 0)
    # Gamma < 1 lifts low/mid responses relative to the peak.
    heatmap = np.power(heatmap, 0.7)

    peak = np.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    # Remove noise: keep only values above 20% of the peak. The peak stays at 1, so no
    # re-normalisation is needed afterwards.
    heatmap[heatmap < 0.2] = 0

    return cv2.resize(heatmap, (224, 224))


def apply_unified_gradcam(original_img, heatmap, alpha=0.6):
    """
    Overlay a Grad-CAM heatmap on the original image using a unified style.

    The image is resized to 224x224 and converted to 3 channels, the heatmap is
    median-blurred, colour-mapped with COLORMAP_HOT and blended in, and the regions
    above half intensity are outlined in white.

    Args:
        original_img: numpy array or PIL image (grayscale, RGB or RGBA).
        heatmap: 2-D array with values expected in [0, 1].
        alpha: Heatmap blend weight; the image is weighted by ``1 - alpha * 0.5``.

    Returns:
        A 224x224x3 uint8 RGB array.
    """
    if isinstance(original_img, Image.Image):
        original_img = np.array(original_img)

    original_img = cv2.resize(original_img, (224, 224))
    if original_img.ndim == 2:
        original_img = np.stack([original_img] * 3, axis=2)
    elif original_img.shape[2] == 1:
        original_img = np.concatenate([original_img] * 3, axis=2)
    elif original_img.shape[2] == 4:
        original_img = original_img[:, :, :3]  # Drop alpha so it blends with a 3-channel heatmap.

    heatmap = np.clip(np.nan_to_num(np.asarray(heatmap, dtype=np.float64)), 0.0, 1.0)
    if heatmap.shape != (224, 224):
        heatmap = cv2.resize(heatmap, (224, 224))

    # medianBlur with a 7x7 kernel requires uint8 input.
    heatmap_uint8 = cv2.medianBlur(np.uint8(heatmap * 255), 7)
    heatmap = heatmap_uint8 / 255.0

    # COLORMAP_HOT for medical visualization; OpenCV returns BGR.
    heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_HOT)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    superimposed_img = heatmap_colored * alpha + original_img.astype(float) * (1 - alpha * 0.5)
    superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)

    # Outline the most significant regions. [-2] works with both the OpenCV 3
    # (3-tuple) and OpenCV 4 (2-tuple) return signatures.
    binary_heatmap = (heatmap > 0.5).astype(np.uint8) * 255
    contours = cv2.findContours(binary_heatmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    cv2.drawContours(superimposed_img, contours, -1, (255, 255, 255), 1)

    return superimposed_img


def _run_inference(model, model_name, image):
    """Preprocess ``image`` for ``model_name`` and run the model.

    Returns ``(img_tensor, results, labels)``: the preprocessed input (reused for
    Grad-CAM), a {pathology: probability} dict, and the label list that indexes
    the model's outputs.
    """
    if model_name == "DenseNet121":
        img = np.array(Image.fromarray(image).convert('RGB'))
        img_processed = preprocess_image_densenet(img)
        transform = transforms.Compose([
            xrv.datasets.XRayCenterCrop(),
            xrv.datasets.XRayResizer(224),
        ])
        img_tensor = torch.from_numpy(transform(img_processed))
        with torch.no_grad():
            probabilities = model(img_tensor.unsqueeze(0)).squeeze(0).cpu().numpy()
        # Use the model's own output labels: TorchXRayVision does not order them like
        # NIH_PATHOLOGIES and may pad unsupported outputs with empty names.
        labels = [str(name) for name in getattr(model, "pathologies", NIH_PATHOLOGIES)]
        results = {name: float(prob) for name, prob in zip(labels, probabilities) if name}
        return img_tensor, results, labels

    if MODELS[model_name] == "efficientnet_b3_custom":
        img_tensor = efficientnet_b3_custom.preprocess_image(image)
        return img_tensor, efficientnet_b3_custom.predict(model, img_tensor), NIH_PATHOLOGIES

    img_tensor = preprocess_image_efficientnet(image)
    with torch.no_grad():
        output = model(img_tensor.unsqueeze(0))
    probabilities = torch.sigmoid(output).squeeze(0).cpu().numpy()
    results = {name: float(prob) for name, prob in zip(NIH_PATHOLOGIES, probabilities)}
    return img_tensor, results, NIH_PATHOLOGIES


def _format_results_html(model_name, sorted_results, top_pathologies,
                         confidence_threshold, gradcam_generated):
    """Render predictions and the Grad-CAM explanation as HTML for the results panel."""
    parts = [f"<h3>Results ({html.escape(model_name)})</h3>"]

    if top_pathologies:
        parts.append(f"<p><b>Conditions at or above the {confidence_threshold:.0%} "
                     f"confidence threshold:</b></p><ul>")
        parts.extend(f"<li><b>{html.escape(str(name))}</b>: {float(sorted_results[name]):.1%}</li>"
                     for name in top_pathologies)
        parts.append("</ul>")
    else:
        parts.append(f"<p>No condition reached the {confidence_threshold:.0%} "
                     f"confidence threshold.</p>")

    parts.append("<details><summary>All predicted probabilities</summary><ul>")
    parts.extend(f"<li>{html.escape(str(name))}: {float(prob):.1%}</li>"
                 for name, prob in sorted_results.items())
    parts.append("</ul></details>")

    parts.append("<p>Grad-CAM highlights regions that influenced the model's prediction "
                 "for the detected conditions. "
                 "Red/yellow areas indicate regions of high importance for the diagnosis.</p>")
    parts.append("<p><b>Improved Visualization:</b> This enhanced Grad-CAM uses medical "
                 "imaging-specific techniques to better highlight clinically relevant regions. ")
    if not gradcam_generated:
        parts.append("No Grad-CAM overlay was generated, so the original image is shown.")
    elif len(top_pathologies) > 1:
        parts.append(f"The visualization focuses on {html.escape(str(top_pathologies[0]))}, "
                     f"the highest-ranked of the {len(top_pathologies)} detected conditions.")
    else:
        parts.append("The visualization focuses on the most significant condition detected in the X-ray.")
    parts.append("</p>")

    return "".join(parts)


def predict_with_gradcam(image, model_name, confidence_threshold=0.5):
    """Make predictions and generate a Grad-CAM visualization.

    Args:
        image: Uploaded image as a numpy array (from Gradio), or None.
        model_name: Key of ``MODELS``.
        confidence_threshold: Minimum probability for a condition to count as detected.

    Returns:
        ``(image, html)``: the Grad-CAM overlay for the top detected condition (or the
        original image if none was produced) and an HTML results summary; on failure
        ``(None, error_message)``.
    """
    if image is None:
        return None, "Please upload an image."

    try:
        # Create weights directory if it doesn't exist (for EfficientNet models).
        if "EfficientNet" in model_name:
            os.makedirs("weights", exist_ok=True)

        model = load_model(model_name)
        if model is None:
            return None, f"Failed to load model {model_name}. Please check the console for details."

        # Save original image for visualization.
        original_img = np.array(image).copy()

        img_tensor, results, labels = _run_inference(model, model_name, image)

        sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
        top_pathologies = [name for name, prob in sorted_results.items()
                           if prob >= confidence_threshold]

        # Grad-CAM for the top detected condition. It is auxiliary: a failure is
        # reported but the predictions are still returned.
        gradcam_img = None
        if top_pathologies and top_pathologies[0] in labels:
            target_idx = list(labels).index(top_pathologies[0])
            try:
                heatmap = compute_unified_gradcam(model, model_name, img_tensor, target_idx)
                if heatmap is not None:
                    gradcam_img = apply_unified_gradcam(original_img, heatmap)
            except Exception as e:
                print(f"Grad-CAM failed for {model_name}: {e}")
                print(traceback.format_exc())

        gradcam_generated = gradcam_img is not None
        if not gradcam_generated:
            gradcam_img = original_img

        result_html = _format_results_html(model_name, sorted_results, top_pathologies,
                                           confidence_threshold, gradcam_generated)
        return gradcam_img, result_html

    except Exception as e:
        print(f"Error processing image: {str(e)}")
        print(traceback.format_exc())
        return None, f"Error processing image: {str(e)}"


# Create the Gradio interface.
with gr.Blocks(title="Chest X-ray Disease Classifier with Improved Grad-CAM") as demo:
    gr.Markdown("# Chest X-ray Disease Classifier with Improved Grad-CAM")
    gr.Markdown("Upload a chest X-ray image and select a model to detect potential conditions. "
                "The enhanced Grad-CAM visualization will highlight clinically relevant regions "
                "influencing the diagnosis.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Chest X-ray Image", type="numpy")
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="EfficientNet-B3",  # Default to the DannyNet model
                label="Select Model",
            )
            confidence = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.05,
                label="Confidence Threshold",
            )
            submit_button = gr.Button("Analyze X-ray")

        with gr.Column(scale=2):
            gradcam_image = gr.Image(label="Grad-CAM Visualization")
            output_text = gr.HTML(label="Results")

    submit_button.click(
        fn=predict_with_gradcam,
        inputs=[input_image, model_dropdown, confidence],
        outputs=[gradcam_image, output_text],
    )


# Launch the app.
if __name__ == "__main__":
    demo.launch()