Spaces:
Sleeping
Sleeping
""" | |
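Chest X-ray model loader for NIH ChestX-ray14 classification.
This module provides functions to load and use a 14-label chest X-ray
classifier trained on the NIH ChestX-ray14 dataset. Despite the
"EfficientNet-B3" name (kept in the ``CustomEfficientNetB3`` class for
backward compatibility), the network is a DenseNet-121 backbone with a
linear classification head.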
""" | |
import os | |
import torch | |
import torch.nn as nn | |
from torchvision.models import densenet121, DenseNet121_Weights | |
import torchvision.transforms as transforms | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
# Output label names, index-aligned with the model's 14 logits (predict()
# zips this list with the sigmoid outputs). The list is alphabetical; it must
# match the label order used when the checkpoint was trained -- presumably
# this same order, confirm against the training script before reordering.
DISEASE_LIST = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Effusion',
    'Emphysema',
    'Fibrosis',
    'Hernia',
    'Infiltration',
    'Mass',
    'Nodule',
    'Pleural_Thickening',
    'Pneumonia',
    'Pneumothorax',
]
# Model architecture: a DenseNet-121 backbone with a custom linear classifier.
# The "EfficientNetB3" class name is historical and kept so existing callers
# keep working.
class CustomEfficientNetB3(nn.Module):
    """DenseNet-121 feature extractor followed by a single linear layer.

    Args:
        num_classes: Number of output logits (14 for NIH ChestX-ray14).
        pretrained: If True (default, the original behaviour), initialise the
            backbone with torchvision's ImageNet weights, which may require a
            download. Pass False when the weights will be overwritten from a
            checkpoint anyway, or when running offline.
    """

    def __init__(self, num_classes=14, *, pretrained=True):
        super().__init__()
        weights = DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        base_model = densenet121(weights=weights)
        # Only the convolutional trunk is reused; torchvision's classifier is
        # replaced by a fresh layer with ``num_classes`` outputs.
        self.features = base_model.features
        self.classifier = nn.Linear(base_model.classifier.in_features, num_classes)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits of shape (batch, num_classes)."""
        x = self.features(x)
        # NOTE(review): torchvision's DenseNet.forward applies a ReLU after
        # ``features``; this model does not. Presumably the checkpoint was
        # trained with exactly this forward, so do not add one here.
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        return self.classifier(x)
# Image preprocessing function
def preprocess_image(img):
    """
    Preprocess an image into a normalized model input tensor.

    The image is converted to RGB, resized so its shorter side is 256 px,
    center-cropped to 224x224, converted to a float tensor and normalized
    with the ImageNet channel means/stds.

    Args:
        img: PIL Image or numpy array. Numpy arrays may be (H, W) or
            (H, W, C); a trailing singleton channel (H, W, 1) is squeezed
            first because PIL cannot build an image from that shape.

    Returns:
        torch.Tensor: Preprocessed image tensor of shape (3, 224, 224).
    """
    if isinstance(img, np.ndarray):
        if img.ndim == 3 and img.shape[2] == 1:
            # Image.fromarray raises TypeError for (H, W, 1) arrays.
            img = img[:, :, 0]
        img = Image.fromarray(img)
    # The backbone expects 3 channels whatever the source mode (L, RGBA, P...).
    img = img.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet statistics, matching the ImageNet-pretrained backbone.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return transform(img)
def load_model(model_path, device='cpu'):
    """
    Load the 14-class chest X-ray model from a checkpoint file.

    Args:
        model_path: Path to the model checkpoint. It may contain a bare state
            dict, or a dict holding the weights under 'state_dict' or
            'model_state_dict'. A 'module.' key prefix (as written by
            nn.DataParallel) is stripped.
        device: Device to load the model on ('cpu' or 'cuda').

    Returns:
        model: Loaded model on ``device``, in eval mode.

    Raises:
        RuntimeError: If the checkpoint cannot be read or its weights do not
            fit the architecture.
    """
    model = CustomEfficientNetB3(num_classes=14)
    try:
        checkpoint = _load_checkpoint(model_path, device)
        # Applied once: a key/shape mismatch would not be cured by re-reading
        # the same file with different torch.load settings.
        model.load_state_dict(_extract_state_dict(checkpoint))
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {e}") from e
    model.to(device)
    model.eval()
    return model


def _load_checkpoint(model_path, device):
    """Read a checkpoint with torch.load, tolerating PyTorch version differences."""
    # Do not put ``import torch...`` / ``import numpy as np`` inside this
    # function: that makes the name local to the whole function, and every
    # use before the import raises UnboundLocalError.
    #
    # SECURITY: weights_only=False runs the full pickle machinery, which can
    # execute arbitrary code. Only load checkpoints from trusted sources.
    try:
        return torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 does not accept the ``weights_only`` keyword.
        return torch.load(model_path, map_location=device)
    except Exception as e:
        print(f"Error loading with weights_only=False: {e}")
    # Fallback: default torch.load settings (the restricted weights-only
    # unpickler on PyTorch >= 2.6), allow-listing the numpy scalar type when
    # ``add_safe_globals`` exists (PyTorch >= 2.4).
    add_safe_globals = getattr(torch.serialization, 'add_safe_globals', None)
    if add_safe_globals is not None:
        add_safe_globals([np.core.multiarray.scalar])
    return torch.load(model_path, map_location=device)


def _extract_state_dict(checkpoint):
    """Return the parameter mapping stored in ``checkpoint``."""
    if isinstance(checkpoint, dict):
        for key in ('state_dict', 'model_state_dict'):
            if key in checkpoint:
                checkpoint = checkpoint[key]
                break
    prefix = 'module.'
    if (isinstance(checkpoint, dict) and checkpoint
            and all(isinstance(k, str) and k.startswith(prefix) for k in checkpoint)):
        # Weights saved from an nn.DataParallel-wrapped model.
        checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}
    return checkpoint
def predict(model, img_tensor, device='cpu'):
    """
    Score one preprocessed image against every label in DISEASE_LIST.

    Args:
        model: Model returning one logit per entry of DISEASE_LIST.
        img_tensor: Preprocessed image tensor without a batch dimension
            (a batch axis of size 1 is added here).
        device: Device to run inference on.

    Returns:
        dict: Maps each disease name to its sigmoid probability as a float.
            Labels are scored independently, so the values need not sum to 1.
    """
    with torch.no_grad():
        batch = img_tensor.unsqueeze(0).to(device)
        logits = model(batch)
        scores = torch.sigmoid(logits[0]).cpu().numpy()
    results = {}
    for name, score in zip(DISEASE_LIST, scores):
        results[name] = float(score)
    return results
def register_hooks(model, *, handles=None):
    """
    Register Grad-CAM hooks on ``model.features[-1]``.

    A forward hook stores the layer's output under ``'target_layer'`` in
    ``activation_dict``. When that output is part of the autograd graph, a
    tensor hook is attached as well, so after ``backward()`` the gradient
    flowing into it is stored under the same key in ``gradient_dict``.

    Args:
        model: Model exposing an indexable ``features`` module.
        handles: Optional list. If given, the forward-hook handle is appended
            so the caller can ``handle.remove()`` it when done. Otherwise the
            hook stays on the model for its lifetime, and every further call
            adds another one.

    Returns:
        activation_dict: Dictionary to store activations
        gradient_dict: Dictionary to store gradients
    """
    activation_dict = {}
    gradient_dict = {}

    def get_gradient(name):
        def hook(grad):
            gradient_dict[name] = grad
        return hook

    def get_activation(name):
        def hook(module, inputs, output):
            activation_dict[name] = output
            # Tensor hooks can only be attached to tensors that require grad
            # (e.g. not under torch.no_grad()).
            if isinstance(output, torch.Tensor) and output.requires_grad:
                output.register_hook(get_gradient(name))
        return hook

    # Hook the last module of the feature extractor.
    target_layer = model.features[-1]
    handle = target_layer.register_forward_hook(get_activation('target_layer'))
    if handles is not None:
        handles.append(handle)
    return activation_dict, gradient_dict
def compute_gradcam(model, img_tensor, target_class_idx=None, device='cpu'):
    """
    Compute a Grad-CAM heatmap for a single image.

    The output channels of ``model.features[-1]`` are weighted by the mean
    *absolute* gradient of the target logit, summed over channels and passed
    through ReLU. Standard Grad-CAM uses the signed mean gradient; the
    absolute value here is a deliberate choice. The map is then
    gamma-corrected (0.7), normalized to [0, 1], thresholded at 0.2 and
    resized.

    Args:
        model: Model exposing an indexable ``features`` module whose last
            element produces the feature maps to explain.
        img_tensor: Preprocessed image tensor without a batch dimension.
        target_class_idx: Index of the target class (if None, uses the
            highest-scoring class).
        device: Device to run on.

    Returns:
        numpy.ndarray: Grad-CAM heatmap (224x224) with values in [0, 1].
    """
    captured = {}

    def _capture(module, inputs, output):
        captured['activations'] = output

    # Use a private hook and always remove it, so repeated calls do not pile
    # up hooks on the model (each holding activations and their graphs).
    handle = model.features[-1].register_forward_hook(_capture)
    try:
        # Gradients are needed even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            x = img_tensor.detach().clone().to(device)
            # Ensures an autograd graph exists even if the weights are frozen.
            x.requires_grad_(True)
            output = model(x.unsqueeze(0))
            if target_class_idx is None:
                target_class_idx = torch.argmax(output).item()
            activations = captured['activations']
            try:
                # d(target logit) / d(activations). No backward() into the
                # parameters is needed for Grad-CAM.
                gradients = torch.autograd.grad(
                    output[0, target_class_idx], activations)[0]
            except RuntimeError as e:
                print(f"Warning: Could not compute gradients ({e}), using dummy gradients")
                gradients = torch.ones_like(activations)
    finally:
        handle.remove()

    # Weight each channel and sum. Vectorized, and it leaves the captured
    # activation tensor unmodified.
    activations = activations.detach()
    channel_weights = gradients.detach().abs().mean(dim=(0, 2, 3))
    cam = (activations * channel_weights.view(1, -1, 1, 1)).sum(dim=1)[0]
    heatmap = cam.cpu().numpy()

    # ReLU: keep only regions that push the score up.
    heatmap = np.maximum(heatmap, 0)
    # Gamma < 1 lifts weaker responses relative to the peak.
    heatmap = np.power(heatmap, 0.7)
    peak = np.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    # Drop everything below 20% of the peak to suppress noise. The peak is
    # now 1 (or the map is all zero), so no re-normalization is required.
    heatmap[heatmap < 0.2] = 0
    return cv2.resize(heatmap, (224, 224))
def apply_gradcam(original_img, heatmap, alpha=0.6, match_preprocessing=True):
    """
    Overlay a Grad-CAM heatmap on the image it was computed from.

    Args:
        original_img: PIL Image, or numpy array (grayscale, RGB or RGBA;
            numpy input is assumed to be RGB-ordered, not OpenCV BGR).
        heatmap: Grad-CAM heatmap with values in [0, 1]. It is resized to the
            224x224 overlay size if its shape differs.
        alpha: Weight of the colored heatmap in the blend.
        match_preprocessing: If True (default), crop the image the way the
            model input is built (shorter side resized to 256, center crop
            224), so the heatmap lines up with the region the model saw.
            If False, squash the whole image to 224x224 (the old behaviour,
            which misaligns the overlay).

    Returns:
        numpy.ndarray: 224x224x3 uint8 RGB image with the heatmap overlay and
        white contours around regions where the smoothed heatmap exceeds 0.5.
    """
    # Normalize the input to an (H, W) or (H, W, 3) array.
    if isinstance(original_img, Image.Image):
        # convert() handles L, RGBA, P, ... modes in one step.
        original_img = np.array(original_img.convert('RGB'))
    else:
        original_img = np.asarray(original_img)
        if original_img.ndim == 3 and original_img.shape[2] == 4:
            original_img = original_img[:, :, :3]  # drop the alpha channel

    if match_preprocessing:
        original_img = _resize_and_center_crop(original_img, resize_to=256, crop=224)
    else:
        original_img = cv2.resize(original_img, (224, 224))

    # Grayscale -> 3 identical channels so it can be blended with the colormap.
    if original_img.ndim == 2:
        original_img = np.stack([original_img] * 3, axis=2)
    elif original_img.ndim == 3 and original_img.shape[2] == 1:
        original_img = np.concatenate([original_img] * 3, axis=2)

    # Sanitize the heatmap: finite, within [0, 1], same size as the image.
    heatmap = np.clip(np.nan_to_num(np.asarray(heatmap)), 0.0, 1.0)
    if heatmap.shape[:2] != original_img.shape[:2]:
        heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))

    # A 7x7 median blur requires uint8 input in OpenCV.
    heatmap_blurred = cv2.medianBlur(np.uint8(heatmap * 255), 7)
    # Colorize the blurred uint8 map directly; dividing by 255 and multiplying
    # back could truncate values down by one level.
    heatmap_colored = cv2.applyColorMap(heatmap_blurred, cv2.COLORMAP_HOT)
    # OpenCV colormaps are BGR; the image at this point is RGB.
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # The image is scaled by (1 - alpha/2) and the heatmap added on top; the
    # sum can exceed 255, hence the clip.
    blended = heatmap_colored * alpha + original_img.astype(float) * (1 - alpha * 0.5)
    superimposed_img = np.clip(blended, 0, 255).astype(np.uint8)

    # Outline the strongest regions (smoothed heatmap > 0.5) in white.
    strong = (heatmap_blurred.astype(float) / 255.0 > 0.5).astype(np.uint8) * 255
    # [-2] selects the contour list on OpenCV 3.x (3 return values) and 4.x (2).
    contours = cv2.findContours(strong, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    cv2.drawContours(superimposed_img, contours, -1, (255, 255, 255), 1)
    return superimposed_img


def _resize_and_center_crop(img, resize_to=256, crop=224):
    """Mirror torchvision Resize(resize_to) + CenterCrop(crop) geometry on a numpy image."""
    h, w = img.shape[:2]
    # Shorter side -> resize_to, longer side scaled and truncated, as torchvision does.
    if w <= h:
        new_w, new_h = resize_to, int(resize_to * h / w)
    else:
        new_w, new_h = int(resize_to * w / h), resize_to
    img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    top = int(round((new_h - crop) / 2.0))
    left = int(round((new_w - crop) / 2.0))
    return img[top:top + crop, left:left + crop]