File size: 20,132 Bytes
d93b771
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
"""

Chest X-ray Classification App with Multiple Models

Integrates multiple models with unified Grad-CAM visualization

"""

import os
import torch
import numpy as np
import gradio as gr
import torchxrayvision as xrv
from PIL import Image
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
import numpy
import torch.serialization
import matplotlib.pyplot as plt
import cv2
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend

# Import the custom EfficientNet-B3 model from separate module
# This is actually the DannyNet model but renamed for consistency
import efficientnet_b3_custom

torch.serialization.add_safe_globals([numpy.core.multiarray.scalar])

# Maps the display name shown in the model dropdown to an internal model-type
# key. load_model() branches on these values to decide how to build each model.
MODELS = {
    "DenseNet121": "densenet121-res224-nih",
    "EfficientNet-B3": "efficientnet_b3_custom",  # Served by the efficientnet_b3_custom module (the DannyNet model)
    "EfficientNet-B3 O": "efficientnet_b3",
    "EfficientNet-B0": "efficientnet_b0"
}

# The 14 NIH ChestX-ray14 pathology names. Position in this list is used as the
# class index when naming model outputs and when picking a Grad-CAM target.
# NOTE(review): torchxrayvision models expose their own `pathologies` ordering;
# confirm it matches this list before zipping DenseNet outputs against it.
NIH_PATHOLOGIES = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
    'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation',
    'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
]

# Cache of already-built models keyed by display name, so each model is
# constructed at most once per process.
loaded_models = {}

# Scratch storage for Grad-CAM hooks: get_activation() hooks write into
# `activation`, get_gradient() hooks write into `gradient`. Both are cleared at
# the start of compute_unified_gradcam().
activation = {}
gradient = {}


# Hook functions for Grad-CAM
def get_activation(name):
    """Build a forward hook that records a layer's output under ``name``.

    The returned callable takes ``(module, inputs, output)`` and stores
    ``output`` unchanged in the module-level ``activation`` dict.
    """
    def _record_output(_module, _inputs, layer_output):
        # Stored without .detach() so gradients can still be taken w.r.t. it.
        activation.update({name: layer_output})

    return _record_output


def get_gradient(name):
    """Build a tensor hook that records an incoming gradient under ``name``.

    The returned callable takes the gradient as its only argument, stores it
    in the module-level ``gradient`` dict and returns ``None``.
    """
    def _record_gradient(incoming_grad):
        gradient.update({name: incoming_grad})

    return _record_gradient


def _load_trained_efficientnet(variant, weights_file, display_name):
    """Build an EfficientNet classifier and load locally trained weights.

    Args:
        variant: Size suffix of the architecture, e.g. ``'b0'`` or ``'b3'``.
        weights_file: Checkpoint file name inside the ``weights`` directory.
        display_name: Model name used in the success message.

    Returns:
        The constructed model. If the checkpoint file does not exist the
        model is returned with freshly initialised (untrained) weights.
    """
    num_classes = len(NIH_PATHOLOGIES)

    # Import EfficientNet dynamically to avoid a hard dependency.
    try:
        from efficientnet_pytorch import EfficientNet
        model = EfficientNet.from_name(f'efficientnet-{variant}', num_classes=num_classes)
    except ImportError:
        # Fall back to torchvision; `weights=None` replaces the deprecated
        # `pretrained=False` and likewise builds an untrained network.
        builder = getattr(models, f'efficientnet_{variant}')
        model = builder(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)

    model_path = os.path.join("weights", weights_file)
    if not os.path.exists(model_path):
        print(f"Model file not found: {model_path}")
        print("Using a new model instance. Please place your trained model in the weights directory.")
        return model

    # weights_only=False is required for these checkpoints on PyTorch 2.6+.
    # NOTE: this unpickles arbitrary objects -- only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # strict=False tolerates naming differences, but report them: if most keys
    # are missing the model is effectively untrained.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: {model_path} did not match the model exactly "
            f"({len(load_result.missing_keys)} missing keys, "
            f"{len(load_result.unexpected_keys)} unexpected keys)."
        )

    print(f"Successfully loaded {display_name} from {model_path}")
    return model


def load_model(model_name):
    """Return the model for ``model_name``, building and caching it on first use.

    Args:
        model_name: A key of ``MODELS`` (the display name from the dropdown).

    Returns:
        The loaded model, or ``None`` if it could not be loaded (details are
        printed to the console).

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
    """
    if model_name in loaded_models:
        return loaded_models[model_name]

    model_type = MODELS[model_name]

    try:
        if model_name == "DenseNet121":
            # Load DenseNet121 from TorchXRayVision
            model = xrv.models.DenseNet(weights="densenet121-res224-nih")

        elif model_type == "efficientnet_b3_custom":
            # Load the custom EfficientNet-B3 model (actually DannyNet)
            model_path = "dannynet-55-best_model_20250422-211522.pth"
            if not os.path.exists(model_path):
                print(f"Model file not found: {model_path}")
                print("Please place the model file in the same directory as this script.")
                return None

            model = efficientnet_b3_custom.load_model(model_path, device='cpu')
            print(f"Successfully loaded EfficientNet-B3 from {model_path}")
            # NOTE(review): eval() is not called on this model here; presumably
            # efficientnet_b3_custom.load_model handles it -- confirm.
            loaded_models[model_name] = model
            return model

        elif model_type == "efficientnet_b3":
            model = _load_trained_efficientnet(
                'b3', "best_model_b3.pt", "EfficientNet-B3 Original")

        elif model_type == "efficientnet_b0":
            model = _load_trained_efficientnet(
                'b0', "best_model_b0.pt", "EfficientNet-B0")

        else:
            print(f"Error loading model {model_name}: unsupported model type '{model_type}'")
            return None

        model.eval()
        loaded_models[model_name] = model
        return model

    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return None


def preprocess_image_densenet(img):
    """Turn an image array into the single-channel input used for DenseNet.

    Colour images (more than two dimensions) are averaged over axis 2, the
    result is passed through ``xrv.datasets.normalize`` with a maximum value
    of 255, and a leading channel axis is added to 2-D results.
    """
    # Collapse the colour axis, if there is one.
    gray = img.mean(2) if len(img.shape) > 2 else img

    scaled = xrv.datasets.normalize(gray, 255)

    # Already has a channel axis (or some other rank): hand back unchanged.
    if len(scaled.shape) != 2:
        return scaled

    return scaled[None, ...]


def preprocess_image_efficientnet(img, img_size=224):
    """Convert an image into a normalised tensor for the EfficientNet models.

    Accepts a numpy array or a PIL image, forces RGB mode, resizes to
    ``img_size`` x ``img_size`` and normalises with the ImageNet channel
    mean and standard deviation.
    """
    pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img

    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)

    return pipeline(pil_img.convert('RGB'))


# Unified Grad-CAM implementation for all models
def compute_unified_gradcam(model, model_name, img_tensor, target_class_idx):
    """Compute a Grad-CAM style heatmap with one recipe shared by all models.

    Args:
        model: The loaded model for ``model_name``.
        model_name: A key of ``MODELS``; selects the hooked layer and whether
            a sigmoid is applied to the model output.
        img_tensor: Preprocessed image tensor without a batch dimension.
        target_class_idx: Output index to explain, or ``None`` to explain the
            highest-scoring output.

    Returns:
        A 224x224 float heatmap with values in [0, 1], or ``None`` if the
        hooked layer produced no activation. For the custom EfficientNet-B3
        model the result of ``efficientnet_b3_custom.compute_gradcam`` is
        returned unchanged.
    """
    # For the custom EfficientNet-B3 model (DannyNet), use its dedicated function
    if MODELS[model_name] == "efficientnet_b3_custom":
        return efficientnet_b3_custom.compute_gradcam(model, img_tensor, target_class_idx)

    # Select the layer whose activations are visualised.
    if model_name == "DenseNet121":
        target_layer = model.features.denseblock3
    elif "EfficientNet" in model_name:
        # efficientnet_pytorch exposes `_blocks`, torchvision exposes `features`;
        # either way use the middle block.
        blocks = model._blocks if hasattr(model, '_blocks') else model.features
        target_layer = blocks[len(blocks) // 2]
    else:
        # Default to last feature layer
        target_layer = model.features[-1]

    activation.clear()
    gradient.clear()
    handle = target_layer.register_forward_hook(get_activation('target_layer'))
    try:
        # Ensure input tensor requires gradients so an autograd graph is built.
        img_tensor_for_gradcam = img_tensor.clone().requires_grad_(True)

        model.zero_grad()
        output = model(img_tensor_for_gradcam.unsqueeze(0))
        # The DenseNet output is used as-is; the other models emit logits.
        if model_name != "DenseNet121":
            output = torch.sigmoid(output)

        activations = activation.get('target_layer')
    finally:
        # Always detach the hook: the model is cached and reused, so a hook
        # left behind by a failed forward pass would fire on every later call.
        handle.remove()
        # Drop the global reference so the activation (and its autograd graph)
        # is not kept alive after this function returns.
        activation.clear()

    if activations is None:
        print("No activation captured")
        return None

    # Target for backprop
    if target_class_idx is not None:
        score = output[0, target_class_idx]
    else:
        score = output.max(dim=1)[0][0]

    try:
        # d(score)/d(activations); a separate score.backward() is not needed.
        gradients = torch.autograd.grad(score, activations)[0]
    except Exception as e:
        print(f"Gradient calculation failed: {e}")
        # Best-effort fallback: uniform weights, i.e. a plain activation sum.
        gradients = torch.ones_like(activations)

    # Global-average-pool the absolute gradients into one weight per channel.
    channel_weights = torch.mean(torch.abs(gradients), dim=[0, 2, 3])

    # Weight the activation maps (on a detached copy, so the captured tensor
    # is not modified in place) and sum over channels.
    weighted = activations.detach() * channel_weights.detach().view(1, -1, 1, 1)
    heatmap = torch.sum(weighted, dim=1).squeeze().cpu().numpy()

    # ReLU on the heatmap
    heatmap = np.maximum(heatmap, 0)

    # Gamma correction to enhance contrast (values below 1 lift bright regions).
    gamma = 0.7
    heatmap = np.power(heatmap, gamma)

    # Normalize heatmap
    peak = np.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak

    # Remove noise: keep only values of at least 20% of the maximum. The
    # maximum is still 1 afterwards, so no second normalisation is required.
    threshold = 0.2
    heatmap[heatmap < threshold] = 0

    # Resize to 224x224
    heatmap = cv2.resize(heatmap, (224, 224))

    return heatmap


# Unified Grad-CAM overlay function for all models
def apply_unified_gradcam(original_img, heatmap, alpha=0.6):
    """Overlay a Grad-CAM heatmap on an image with one style for all models.

    Args:
        original_img: PIL image or numpy array (grayscale, RGB or RGBA).
        heatmap: 2-D array of importance values, expected in [0, 1].
        alpha: Weight of the coloured heatmap in the blend.

    Returns:
        A 224x224x3 ``uint8`` RGB image with the heatmap blended in and white
        contours drawn around regions whose blurred value exceeds 0.5.
    """
    # Convert to numpy if it's a PIL Image
    if isinstance(original_img, Image.Image):
        original_img = np.array(original_img)

    # Resize original image to 224x224
    original_img = cv2.resize(original_img, (224, 224))

    # Bring the image to exactly three channels.
    if original_img.ndim == 2:
        original_img = np.stack([original_img] * 3, axis=2)
    elif original_img.shape[2] == 1:
        original_img = np.concatenate([original_img] * 3, axis=2)
    elif original_img.shape[2] == 4:
        # Drop the alpha channel; a 4-channel image cannot be blended with
        # the 3-channel colour map below.
        original_img = original_img[:, :, :3]

    # Make sure the heatmap matches the image size before blending.
    heatmap = np.asarray(heatmap)
    if heatmap.shape[:2] != (224, 224):
        heatmap = cv2.resize(heatmap, (224, 224))

    # Clip before the uint8 conversion: values outside [0, 1] would otherwise
    # wrap around. Median blur smooths the map.
    heatmap_uint8 = np.uint8(np.clip(heatmap, 0.0, 1.0) * 255)
    heatmap_blurred = cv2.medianBlur(heatmap_uint8, 7)

    # Apply colormap - COLORMAP_HOT for better medical visualization. The
    # blurred uint8 map is used directly, avoiding a float round trip.
    heatmap_colored = cv2.applyColorMap(heatmap_blurred, cv2.COLORMAP_HOT)
    # OpenCV colour maps are BGR; the image is RGB.
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # Superimpose heatmap on original image
    original_img_float = original_img.astype(float)
    superimposed_img = heatmap_colored * alpha + original_img_float * (1 - alpha * 0.5)
    superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)

    # Contour lines around the most significant regions (blurred value > 0.5,
    # i.e. above 127 on the uint8 scale).
    binary_heatmap = (heatmap_blurred > 127).astype(np.uint8) * 255
    contours, _ = cv2.findContours(binary_heatmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(superimposed_img, contours, -1, (255, 255, 255), 1)

    return superimposed_img


def predict_with_gradcam(image, model_name, confidence_threshold=0.5):
    """Classify a chest X-ray and build a Grad-CAM overlay for the top finding.

    Args:
        image: The uploaded image as a numpy array, or ``None``.
        model_name: A key of ``MODELS``.
        confidence_threshold: Minimum probability for a condition to count as
            detected.

    Returns:
        A ``(image, html)`` tuple. ``image`` is the Grad-CAM overlay, the
        original image when no overlay was produced, or ``None`` on failure.
        ``html`` is the formatted result or an error message.
    """
    if image is None:
        return None, "Please upload an image."

    try:
        # Create weights directory if it doesn't exist (for EfficientNet models)
        if "EfficientNet" in model_name:
            os.makedirs("weights", exist_ok=True)

        # Load the model
        model = load_model(model_name)

        if model is None:
            return None, f"Failed to load model {model_name}. Please check the console for details."

        # Save original image for visualization
        original_img = np.array(image).copy()

        # Output index of each label; used to pick the Grad-CAM target.
        class_index = {name: idx for idx, name in enumerate(NIH_PATHOLOGIES)}

        # Process based on model type
        if model_name == "DenseNet121":
            # Read and preprocess the image for DenseNet from TorchXRayVision
            img = np.array(Image.fromarray(image).convert('RGB'))
            img_processed = preprocess_image_densenet(img)

            transform = transforms.Compose([
                xrv.datasets.XRayCenterCrop(),
                xrv.datasets.XRayResizer(224)
            ])
            img_tensor = torch.from_numpy(transform(img_processed))

            # Make prediction
            with torch.no_grad():
                output = model(img_tensor.unsqueeze(0))
                probabilities = output.squeeze().numpy()

            # TorchXRayVision orders its outputs by `model.pathologies`, which
            # is not the order of NIH_PATHOLOGIES; unused slots are empty
            # strings. Label by the model's own list when it is available.
            labels = getattr(model, "pathologies", None) or NIH_PATHOLOGIES
            class_index = {
                name: idx for idx, name in enumerate(labels)
                if name and idx < len(probabilities)
            }
            results = {name: float(probabilities[idx]) for name, idx in class_index.items()}

        elif MODELS[model_name] == "efficientnet_b3_custom":
            # Use the custom EfficientNet-B3 module for preprocessing and prediction
            img_tensor = efficientnet_b3_custom.preprocess_image(image)
            results = efficientnet_b3_custom.predict(model, img_tensor)

        else:  # Other EfficientNet models
            img_tensor = preprocess_image_efficientnet(image)

            # Make prediction
            with torch.no_grad():
                output = model(img_tensor.unsqueeze(0))
                probabilities = torch.sigmoid(output).squeeze().numpy()

            results = {pathology: float(prob) for pathology, prob in zip(NIH_PATHOLOGIES, probabilities)}

        # Sort results by probability (descending)
        sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))

        # Get top pathologies above threshold
        top_pathologies = [p for p, prob in sorted_results.items() if prob >= confidence_threshold]

        # Generate Grad-CAM for the single most probable detected pathology.
        gradcam_img = None
        gradcam_label = None  # the label the overlay was actually built for

        if top_pathologies:
            top_pathology = top_pathologies[0]
            target_idx = class_index[top_pathology]

            heatmap = compute_unified_gradcam(model, model_name, img_tensor, target_idx)

            if heatmap is not None:
                gradcam_img = apply_unified_gradcam(original_img, heatmap)
                gradcam_label = top_pathology

        # If no Grad-CAM was generated, return the original image
        if gradcam_img is None:
            gradcam_img = original_img

        # Format results for display
        parts = ["<h3>Detected Conditions:</h3><ul>"]
        detected = [
            f"<li><b>{pathology}</b>: {prob:.4f} ({prob * 100:.1f}%)</li>"
            for pathology, prob in sorted_results.items()
            if prob >= confidence_threshold
        ]
        parts.extend(detected or ["<li>No conditions detected above the confidence threshold.</li>"])
        parts.append("</ul>")

        # Add a section for all probabilities; flag only the condition the
        # heatmap was really computed for.
        parts.append("<h3>All Probabilities:</h3><ul>")
        for pathology, prob in sorted_results.items():
            entry = f"<li>{pathology}: {prob:.4f} ({prob * 100:.1f}%)"
            if pathology == gradcam_label:
                entry += " - <span style='color:red'>Used for Grad-CAM</span>"
            parts.append(entry + "</li>")
        parts.append("</ul>")

        # Add explanation about Grad-CAM
        parts.append("<h3>About Grad-CAM:</h3>")
        parts.append("<p>Grad-CAM highlights regions that influenced the model's prediction for the detected conditions. ")
        parts.append("Red/yellow areas indicate regions of high importance for the diagnosis.</p>")

        # Add explanation about the improved visualization
        parts.append("<p><b>Improved Visualization:</b> This enhanced Grad-CAM uses medical imaging-specific techniques to better highlight clinically relevant regions. ")
        if gradcam_label is not None and len(top_pathologies) > 1:
            # The heatmap is computed for one class only, so say so.
            parts.append(
                f"{len(top_pathologies)} conditions were detected; the visualization "
                f"focuses on the most probable one ({gradcam_label}).</p>"
            )
        else:
            parts.append("The visualization focuses on the most significant condition detected in the X-ray.</p>")

        return gradcam_img, "".join(parts)

    except Exception as e:
        from html import escape
        import traceback
        traceback_str = traceback.format_exc()
        print(f"Error processing image: {str(e)}")
        print(traceback_str)
        # The message is rendered as HTML, so escape the exception text.
        return None, f"Error processing image: {escape(str(e))}"


# Create the Gradio interface
# Shared by the browser tab title and the page heading.
_APP_TITLE = "Chest X-ray Disease Classifier with Improved Grad-CAM"

with gr.Blocks(title=_APP_TITLE) as demo:
    gr.Markdown(f"# {_APP_TITLE}")
    gr.Markdown("Upload a chest X-ray image and select a model to detect potential conditions. The enhanced Grad-CAM visualization will highlight clinically relevant regions influencing the diagnosis.")

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Chest X-ray Image", type="numpy")
            # Defaults to the DannyNet-backed model.
            model_dropdown = gr.Dropdown(choices=list(MODELS), value="EfficientNet-B3", label="Select Model")
            confidence = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold")
            submit_button = gr.Button("Analyze X-ray")

        # Right column: outputs only (the original image is intentionally not shown).
        with gr.Column(scale=2):
            gradcam_image = gr.Image(label="Grad-CAM Visualization")
            output_text = gr.HTML(label="Results")

    # Run the prediction when the button is pressed.
    submit_button.click(
        predict_with_gradcam,
        [input_image, model_dropdown, confidence],
        [gradcam_image, output_text],
    )


# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()