"""
EfficientNet-B3 Model Loader for NIH Chest X-ray Classification
This module provides functions to load and use the EfficientNet-B3 model
trained on the NIH ChestX-ray14 dataset with advanced techniques.
"""
import os
import torch
import torch.nn as nn
from torchvision.models import densenet121, DenseNet121_Weights
import torchvision.transforms as transforms
import numpy as np
import cv2
from PIL import Image
# Disease labels - same order as in the original NIH dataset
DISEASE_LIST = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Effusion',
    'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass',
    'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax'
]

# Define the model architecture (DenseNet121 base with custom classifier)
class CustomEfficientNetB3(nn.Module):
    def __init__(self, num_classes=14):
        super().__init__()
        base_model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
        self.features = base_model.features
        self.classifier = nn.Linear(base_model.classifier.in_features, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        return self.classifier(x)
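
# A minimal sanity-check sketch (not part of the original pipeline): a dummy
# batch through the network should yield one logit per disease, i.e. a tensor
# of shape [batch, 14]. The helper name `_check_output_shape` is illustrative.
def _check_output_shape():
    model = CustomEfficientNetB3(num_classes=14)
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 224, 224)  # one fake RGB 224x224 image
        logits = model(dummy)
    assert logits.shape == (1, 14), f"unexpected output shape: {logits.shape}"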

# Image preprocessing function
def preprocess_image(img):
    """
    Preprocess an image for the model

    Args:
        img: PIL Image or numpy array

    Returns:
        torch.Tensor: Preprocessed image tensor of shape [3, 224, 224]
    """
    # Convert to PIL Image if it's a numpy array
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # Ensure image is in RGB mode
    img = img.convert('RGB')
    # Define preprocessing transforms (ImageNet normalization statistics)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return transform(img)

def load_model(model_path, device='cpu'):
    """
    Load the model from a checkpoint file

    Args:
        model_path: Path to the model checkpoint
        device: Device to load the model on ('cpu' or 'cuda')

    Returns:
        model: Loaded model in eval mode
    """
    # Create model architecture
    model = CustomEfficientNetB3(num_classes=14)
    # Load weights with error handling for different PyTorch versions
    try:
        # weights_only=False restores the pre-2.6 default behaviour of
        # torch.load (full unpickling); PyTorch >= 2.6 defaults to True
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            # Checkpoint wraps the weights in a 'state_dict' key
            model.load_state_dict(checkpoint['state_dict'])
        else:
            # Checkpoint is the state_dict itself
            model.load_state_dict(checkpoint)
    except Exception as e:
        print(f"Error loading with weights_only=False: {e}")
        try:
            # Fallback: allowlist the numpy scalar type that some checkpoints
            # pickle, then retry with default torch.load settings
            import torch.serialization
            torch.serialization.add_safe_globals([np.core.multiarray.scalar])
            checkpoint = torch.load(model_path, map_location=device)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            else:
                model.load_state_dict(checkpoint)
        except Exception as e2:
            raise RuntimeError(f"Failed to load model: {e2}") from e2
    model.to(device)
    model.eval()
    return model

def predict(model, img_tensor, device='cpu'):
    """
    Make predictions with the model

    Args:
        model: Loaded model
        img_tensor: Preprocessed image tensor of shape [3, 224, 224]
        device: Device to run inference on

    Returns:
        dict: Dictionary mapping disease names to probabilities
    """
    with torch.no_grad():
        output = model(img_tensor.unsqueeze(0).to(device))
        # Multi-label task: an independent sigmoid per disease, not a softmax
        probs = torch.sigmoid(output[0]).cpu().numpy()
    # Create dictionary of results
    results = {disease: float(prob) for disease, prob in zip(DISEASE_LIST, probs)}
    return results
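
# Sketch: one way to turn the probability dict into a list of positive
# findings, sorted by confidence. The flat 0.5 threshold is an assumption for
# illustration; per-class thresholds tuned on a validation set usually work
# better for ChestX-ray14.
def probs_to_findings(results, threshold=0.5):
    return sorted(
        (disease for disease, prob in results.items() if prob >= threshold),
        key=lambda d: -results[d],
    )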

def register_hooks(model):
    """
    Register hooks for Grad-CAM

    Args:
        model: Model

    Returns:
        activation_dict: Dictionary to store activations
        gradient_dict: Dictionary to store gradients (unused in practice;
            compute_gradcam obtains gradients via torch.autograd.grad)
        handle: Forward-hook handle, so the caller can remove the hook
    """
    activation_dict = {}
    gradient_dict = {}

    def get_activation(name):
        def hook(module, input, output):
            activation_dict[name] = output
        return hook

    def get_gradient(name):
        # Defined for completeness; compute_gradcam differentiates through
        # the stored activations instead of registering this tensor hook
        def hook(grad):
            gradient_dict[name] = grad
        return hook

    # Register the hook on the last dense block for better feature visualization
    target_layer = model.features[-1]
    handle = target_layer.register_forward_hook(get_activation('target_layer'))
    return activation_dict, gradient_dict, handle

def compute_gradcam(model, img_tensor, target_class_idx=None, device='cpu'):
    """
    Compute Grad-CAM for the model

    Args:
        model: Model
        img_tensor: Preprocessed image tensor
        target_class_idx: Index of the target class (if None, uses the highest probability class)
        device: Device to run on

    Returns:
        numpy.ndarray: Grad-CAM heatmap (224x224), values in [0, 1]
    """
    # Register the forward hook that captures the target layer's activations
    activation_dict, gradient_dict, hook_handle = register_hooks(model)
    # Clone the tensor to avoid modifying the original
    img_tensor_for_gradcam = img_tensor.clone().to(device)
    img_tensor_for_gradcam.requires_grad_(True)
    # Forward pass
    model.zero_grad()
    output = model(img_tensor_for_gradcam.unsqueeze(0))
    # If target_class_idx is None, use the class with the highest score
    if target_class_idx is None:
        target_class_idx = torch.argmax(output).item()
    # Target for backprop
    one_hot = torch.zeros_like(output)
    one_hot[0, target_class_idx] = 1
    # Backward pass with retain_graph so the graph can be reused below
    output.backward(gradient=one_hot, retain_graph=True)
    # Get activations captured by the forward hook
    activations = activation_dict['target_layer']
    # Try different approaches to get gradients
    gradients = None
    try:
        # First try: differentiate the target score w.r.t. the activations
        if activations.grad_fn is not None:
            gradients = torch.autograd.grad(output[:, target_class_idx].sum(),
                                            activations,
                                            retain_graph=True)[0]
    except Exception as e:
        print(f"First gradient approach failed: {e}")
        try:
            # Second try: use the gradient of the target layer's first
            # parameter as a proxy for the activation gradients
            target_layer = model.features[-1]
            params = [p for p in target_layer.parameters() if p.requires_grad]
            if params:
                gradients = params[0].grad
            if gradients is not None and gradients.shape != activations.shape:
                # Fallback that may not be accurate but is better than nothing
                print("Warning: Gradient shape mismatch, using alternative approach")
                gradients = torch.ones_like(activations)
        except Exception as e2:
            print(f"Second gradient approach failed: {e2}")
    # If we still don't have gradients, create a dummy gradient as last resort
    if gradients is None:
        print("Warning: Could not compute gradients, using dummy gradients")
        gradients = torch.ones_like(activations)
    # Remove the forward hook so repeated calls do not accumulate hooks
    hook_handle.remove()
    # Detach so the in-place weighting below does not touch the autograd graph
    activations = activations.detach()
    gradients = gradients.detach()
    # Global average pooling over absolute gradients: focus on the magnitude
    # of importance rather than just its direction
    pooled_gradients = torch.mean(torch.abs(gradients), dim=[0, 2, 3])
    # Weight activation maps with the pooled gradients
    for i in range(activations.size(1)):
        activations[:, i, :, :] *= pooled_gradients[i]
    # Sum along channels for the raw heatmap
    heatmap = torch.sum(activations, dim=1).squeeze().cpu().numpy()
    # ReLU on the heatmap
    heatmap = np.maximum(heatmap, 0)
    # Gamma correction (values < 1 enhance bright regions) for contrast
    gamma = 0.7
    heatmap = np.power(heatmap, gamma)
    # Normalize heatmap
    if np.max(heatmap) > 0:
        heatmap = heatmap / np.max(heatmap)
    # Threshold to suppress noise: keep only values above 20% of the max
    threshold = 0.2
    heatmap[heatmap < threshold] = 0
    # Re-normalize after thresholding
    if np.max(heatmap) > 0:
        heatmap = heatmap / np.max(heatmap)
    # Resize to the model input resolution
    heatmap = cv2.resize(heatmap, (224, 224))
    return heatmap
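
# Sketch: Grad-CAM for a specific named finding rather than the top-scoring
# class. `disease` must be one of the strings in DISEASE_LIST; this helper is
# illustrative and not part of the original module.
def gradcam_for_disease(model, img_tensor, disease, device='cpu'):
    idx = DISEASE_LIST.index(disease)  # raises ValueError for unknown names
    return compute_gradcam(model, img_tensor, target_class_idx=idx, device=device)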

def apply_gradcam(original_img, heatmap, alpha=0.6):
    """
    Apply Grad-CAM heatmap to the original image

    Args:
        original_img: PIL Image or numpy array
        heatmap: Grad-CAM heatmap with values in [0, 1]
        alpha: Transparency factor for the heatmap overlay

    Returns:
        numpy.ndarray: Image with heatmap overlay (RGB, uint8)
    """
    # Convert to numpy if it's a PIL Image
    if isinstance(original_img, Image.Image):
        original_img = np.array(original_img)
    # Resize original image to match the heatmap
    original_img = cv2.resize(original_img, (224, 224))
    # Convert original image to RGB if it's grayscale
    if len(original_img.shape) == 2:
        original_img = np.stack([original_img] * 3, axis=2)
    elif len(original_img.shape) == 3 and original_img.shape[2] == 1:
        original_img = np.concatenate([original_img] * 3, axis=2)
    # Median-blur the heatmap (as uint8) to smooth out speckle
    heatmap_uint8 = np.uint8(heatmap * 255)
    heatmap_blurred = cv2.medianBlur(heatmap_uint8, 7)
    # Convert back to float in range [0, 1]
    heatmap = heatmap_blurred.astype(float) / 255.0
    # Apply colormap - COLORMAP_HOT reads well on medical images
    heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_HOT)
    # OpenCV colormaps are BGR; convert to RGB to match the input image
    # (after the grayscale handling above, the image is always 3-channel)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
    # Superimpose heatmap on the original image; the 0.5 factor keeps the
    # underlying X-ray visible beneath the colored regions
    original_img_float = original_img.astype(float)
    superimposed_img = heatmap_colored * alpha + original_img_float * (1 - alpha * 0.5)
    superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
    # Outline the most significant regions (> 0.5) with white contour lines
    # to delineate the affected areas more clearly
    binary_heatmap = (heatmap > 0.5).astype(np.uint8) * 255
    contours, _ = cv2.findContours(binary_heatmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(superimposed_img, contours, -1, (255, 255, 255), 1)
    return superimposed_img
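
# End-to-end usage sketch. The checkpoint and image paths below are
# hypothetical placeholders; substitute your own files.
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_model("efficientnet_b3_custom.pth", device=device)  # hypothetical path
    img = Image.open("chest_xray.png").convert('RGB')  # hypothetical path
    img_tensor = preprocess_image(img)
    # Per-disease probabilities, printed highest first
    results = predict(model, img_tensor, device=device)
    for disease, prob in sorted(results.items(), key=lambda kv: -kv[1]):
        print(f"{disease:20s} {prob:.3f}")
    # Grad-CAM overlay for the top-scoring class
    heatmap = compute_gradcam(model, img_tensor, device=device)
    overlay = apply_gradcam(img, heatmap)
    cv2.imwrite("gradcam_overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))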