"""

EfficientNet-B3 Model Loader for NIH Chest X-ray Classification

This module provides functions to load and use the EfficientNet-B3 model

trained on the NIH ChestX-ray14 dataset with advanced techniques.

"""

import torch
import torch.nn as nn
from torchvision.models import densenet121, DenseNet121_Weights
import torchvision.transforms as transforms
import numpy as np
import cv2
from PIL import Image

# Disease labels - same order as in the original NIH dataset
DISEASE_LIST = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Effusion',
    'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass',
    'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax'
]


# Define the model architecture (DenseNet121 base with custom classifier).
# The class name is kept for checkpoint compatibility, but the backbone is
# DenseNet121, not EfficientNet-B3.
class CustomEfficientNetB3(nn.Module):
    def __init__(self, num_classes=14):
        super().__init__()
        base_model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
        self.features = base_model.features
        self.classifier = nn.Linear(base_model.classifier.in_features, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        return self.classifier(x)
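

# Minimal smoke test for the architecture above (a sketch, not part of the
# API: uses random input and the ImageNet-initialized backbone, which is
# downloaded on first run).
def _smoke_test_architecture():
    model = CustomEfficientNetB3(num_classes=14)
    x = torch.randn(1, 3, 224, 224)  # one RGB 224x224 image
    logits = model(x)
    assert logits.shape == (1, 14), f"unexpected output shape: {logits.shape}"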


# Image preprocessing function
def preprocess_image(img):
    """

    Preprocess an image for the model



    Args:

        img: PIL Image or numpy array



    Returns:

        torch.Tensor: Preprocessed image tensor

    """
    # Convert to PIL Image if it's a numpy array
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)

    # Ensure image is in RGB mode
    img = img.convert('RGB')

    # Define preprocessing transforms
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    return transform(img)
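

# Example usage of preprocess_image (a sketch: a synthetic grayscale array
# stands in for a chest X-ray; real inputs would be loaded from disk).
def _example_preprocess():
    fake_xray = np.zeros((512, 512), dtype=np.uint8)  # grayscale placeholder
    tensor = preprocess_image(fake_xray)
    print(tensor.shape)  # torch.Size([3, 224, 224])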


def load_model(model_path, device='cpu'):
    """

    Load the model from a checkpoint file



    Args:

        model_path: Path to the model checkpoint

        device: Device to load the model on ('cpu' or 'cuda')



    Returns:

        model: Loaded model

    """
    # Create model architecture
    model = CustomEfficientNetB3(num_classes=14)

    # Load weights with error handling across PyTorch versions
    try:
        # weights_only=False is needed on PyTorch >= 2.6, where torch.load
        # defaults to weights_only=True; older releases already default to False
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            # If checkpoint contains a state_dict key
            model.load_state_dict(checkpoint['state_dict'])
        else:
            # If checkpoint is the state_dict itself
            model.load_state_dict(checkpoint)
    except Exception as e:
        print(f"Error loading with weights_only=False: {e}")
        try:
            # Add numpy scalar to safe globals as fallback
            import torch.serialization
            import numpy as np
            torch.serialization.add_safe_globals([np.core.multiarray.scalar])

            # Try loading with default settings
            checkpoint = torch.load(model_path, map_location=device)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            else:
                model.load_state_dict(checkpoint)
        except Exception as e2:
            raise RuntimeError(f"Failed to load model: {e2}") from e2

    model.to(device)
    model.eval()
    return model
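

# Example usage of load_model (a sketch: the checkpoint path below is a
# placeholder and must point at a real file trained with this architecture).
def _example_load(model_path="checkpoints/model.pth"):  # hypothetical path
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return load_model(model_path, device=device)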


def predict(model, img_tensor, device='cpu'):
    """

    Make predictions with the model



    Args:

        model: Model

        img_tensor: Preprocessed image tensor

        device: Device to run inference on



    Returns:

        dict: Dictionary mapping disease names to probabilities

    """
    with torch.no_grad():
        output = model(img_tensor.unsqueeze(0).to(device))
        # Multi-label task: an independent sigmoid per disease
        probs = torch.sigmoid(output[0]).cpu().numpy()

    # Create dictionary of results
    results = {disease: float(prob) for disease, prob in zip(DISEASE_LIST, probs)}
    return results
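

# Example: run inference and report the highest-probability findings
# (a sketch; `model` is assumed to come from load_model above).
def _example_top_findings(model, img, k=3, device='cpu'):
    results = predict(model, preprocess_image(img), device=device)
    top = sorted(results.items(), key=lambda kv: kv[1], reverse=True)[:k]
    for disease, prob in top:
        print(f"{disease}: {prob:.3f}")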


def register_hooks(model):
    """

    Register hooks for Grad-CAM



    Args:

        model: Model



    Returns:

        activation_dict: Dictionary to store activations

        gradient_dict: Dictionary to store gradients

    """
    activation_dict = {}
    gradient_dict = {}

    def get_activation(name):
        def hook(module, input, output):
            activation_dict[name] = output
            # Capture gradients flowing back through this activation so the
            # Grad-CAM computation can read them after the backward pass
            # (skipped under torch.no_grad(), where no graph is built)
            if output.requires_grad:
                output.register_hook(get_gradient(name))

        return hook

    def get_gradient(name):
        def hook(grad):
            gradient_dict[name] = grad

        return hook

    # Register the forward hook on the final feature layer (after the last
    # dense block) for better feature visualization
    target_layer = model.features[-1]
    target_layer.register_forward_hook(get_activation('target_layer'))

    return activation_dict, gradient_dict


def compute_gradcam(model, img_tensor, target_class_idx=None, device='cpu'):
    """

    Compute Grad-CAM for the model



    Args:

        model: Model

        img_tensor: Preprocessed image tensor

        target_class_idx: Index of the target class (if None, uses the highest probability class)

        device: Device to run on



    Returns:

        numpy.ndarray: Grad-CAM heatmap (224x224)

    """
    # Register hooks
    activation_dict, gradient_dict = register_hooks(model)

    # Clone the tensor to avoid modifying the original
    img_tensor_for_gradcam = img_tensor.clone().to(device)
    img_tensor_for_gradcam.requires_grad_(True)

    # Forward pass
    model.zero_grad()
    output = model(img_tensor_for_gradcam.unsqueeze(0))

    # If target_class_idx is None, use the class with the highest score
    if target_class_idx is None:
        target_class_idx = torch.argmax(output).item()

    # Target for backprop
    one_hot = torch.zeros_like(output)
    one_hot[0, target_class_idx] = 1

    # Backward pass with retain_graph to avoid "backward through the graph a second time" error
    output.backward(gradient=one_hot, retain_graph=True)

    # Get activations
    activations = activation_dict['target_layer']

    # Prefer the gradients captured by the tensor hook during the backward
    # pass; fall back to torch.autograd.grad, and finally to uniform weights
    # as a last resort so a heatmap is still produced
    gradients = gradient_dict.get('target_layer')
    if gradients is None:
        try:
            gradients = torch.autograd.grad(output[:, target_class_idx].sum(),
                                            activations,
                                            retain_graph=True)[0]
        except Exception as e:
            print(f"Gradient computation failed: {e}")
    if gradients is None:
        print("Warning: could not compute gradients, using uniform weights")
        gradients = torch.ones_like(activations)

    # Use global average pooling with absolute values for better feature highlighting
    # This helps focus on the magnitude of importance rather than just direction
    pooled_gradients = torch.mean(torch.abs(gradients), dim=[0, 2, 3])

    # Weight activation maps with gradients (detach first so the in-place
    # scaling does not corrupt the autograd graph)
    activations = activations.detach()
    for i in range(activations.size(1)):
        activations[:, i, :, :] *= pooled_gradients[i]

    # Sum along channels for final heatmap
    heatmap = torch.sum(activations, dim=1).squeeze().cpu().numpy()

    # ReLU on the heatmap
    heatmap = np.maximum(heatmap, 0)

    # Apply gamma correction to enhance contrast
    gamma = 0.7  # Values less than 1 enhance bright regions
    heatmap = np.power(heatmap, gamma)

    # Normalize heatmap
    if np.max(heatmap) > 0:
        heatmap = heatmap / np.max(heatmap)

    # Apply threshold to remove noise
    # This helps focus on the most important regions
    threshold = 0.2  # Only keep values above 20% of max
    heatmap[heatmap < threshold] = 0

    # Re-normalize after thresholding
    if np.max(heatmap) > 0:
        heatmap = heatmap / np.max(heatmap)

    # Resize to 224x224
    heatmap = cv2.resize(heatmap, (224, 224))

    return heatmap
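

# Example: compute a heatmap for a specific finding rather than the top class
# (a sketch; `model` is assumed to come from load_model above).
def _example_gradcam_for(model, img, disease='Cardiomegaly', device='cpu'):
    idx = DISEASE_LIST.index(disease)
    return compute_gradcam(model, preprocess_image(img),
                           target_class_idx=idx, device=device)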


def apply_gradcam(original_img, heatmap, alpha=0.6):
    """

    Apply Grad-CAM heatmap to the original image



    Args:

        original_img: PIL Image or numpy array

        heatmap: Grad-CAM heatmap

        alpha: Transparency factor



    Returns:

        numpy.ndarray: Image with heatmap overlay

    """
    # Convert to numpy if it's a PIL Image
    if isinstance(original_img, Image.Image):
        original_img = np.array(original_img)

    # Resize original image if needed
    original_img = cv2.resize(original_img, (224, 224))

    # Convert original image to RGB if it's grayscale
    if len(original_img.shape) == 2:
        original_img = np.stack([original_img] * 3, axis=2)
    elif len(original_img.shape) == 3 and original_img.shape[2] == 1:
        original_img = np.concatenate([original_img] * 3, axis=2)

    # Convert heatmap to uint8 before applying median blur
    heatmap_uint8 = np.uint8(heatmap * 255)
    heatmap_blurred = cv2.medianBlur(heatmap_uint8, 7)
    # Convert back to float in range [0,1]
    heatmap = heatmap_blurred.astype(float) / 255.0

    # Apply a colormap to the heatmap; COLORMAP_HOT reads well for medical images
    heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_HOT)

    # OpenCV colormaps are BGR; convert to match the RGB original
    if len(original_img.shape) == 3 and original_img.shape[2] == 3:
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # Create a copy of the original image for overlay
    original_img_float = original_img.astype(float)

    # Superimpose the heatmap on the original image; the weights intentionally
    # sum to more than 1 so the underlying anatomy stays visible, and the
    # result is clipped back to the valid [0, 255] range below
    superimposed_img = heatmap_colored * alpha + original_img_float * (1 - alpha * 0.5)
    superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)

    # Add contour lines for the most significant regions
    # This helps delineate the affected areas more clearly
    binary_heatmap = (heatmap > 0.5).astype(np.uint8) * 255
    contours, _ = cv2.findContours(binary_heatmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(superimposed_img, contours, -1, (255, 255, 255), 1)

    return superimposed_img
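

# End-to-end demo of the full pipeline (a sketch: the checkpoint, input image,
# and output paths below are placeholders, not files shipped with this module).
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_model("checkpoints/model.pth", device=device)  # placeholder path
    img = Image.open("sample_xray.png")  # placeholder path
    img_tensor = preprocess_image(img)

    # Per-disease probabilities
    for disease, prob in predict(model, img_tensor, device=device).items():
        print(f"{disease}: {prob:.3f}")

    # Grad-CAM overlay for the top-scoring class
    heatmap = compute_gradcam(model, img_tensor, device=device)
    overlay = apply_gradcam(img, heatmap)
    cv2.imwrite("gradcam_overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))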