File size: 3,858 Bytes
e5a19d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import torch
import numpy as np
from PIL import Image
from utils import *
from tqdm import tqdm
from gradcam import GradCAM, GradCAMpp
from overlay_image import overlay_numpy

# Names of the 14 findings handled by the classifier. A name's position in
# this list is used as the Grad-CAM class index (see
# GradCamGenerator.target_from_path), so the order must not change.
# NOTE(review): the NIH ChestX-ray14 labels spell one finding as
# 'Pleural_Thickening'; confirm the spaced spelling below matches the
# folder names it is compared against.
DISEASES = [
    'Atelectasis',
    'Cardiomegaly',
    'Effusion',
    'Infiltration',
    'Mass',
    'Nodule',
    'Pneumonia',
    'Pneumothorax',
    'Consolidation',
    'Edema',
    'Emphysema',
    'Fibrosis',
    'Pleural Thickening',
    'Hernia',
]

class GradCamGenerator:
    """Compute Grad-CAM attribution maps for images using a saved classifier.

    The classifier is read from a checkpoint dict whose ``'model'`` entry is a
    pickled ``torch.nn.Module``; Grad-CAM is computed at the sub-module named
    by ``layer``.
    """

    def __init__(self, model_path, layer, overlay=False):
        """Load the classifier and wrap it in a ``GradCAM`` instance.

        Args:
            model_path: Checkpoint path readable by ``torch.load`` holding the
                full model under the key ``'model'``.
            layer: Dotted sub-module name accepted by
                ``nn.Module.get_submodule`` (e.g. ``'features.norm5'``).
            overlay: If True, ``generate_grad_cam`` also calls
                ``overlay_numpy`` to overlay the heatmap on the image.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.load_model(model_path)
        self.layer = layer
        self.overlay = overlay  # Overlay Grad-CAM heatmaps with the image.
        self.layer_module = self.model.get_submodule(layer)
        self.gc_model2 = GradCAM(self.model, self.layer_module)

    def load_model(self, model_path, print_net=False):
        """Load the full model from ``model_path`` and prepare it for Grad-CAM.

        The model is switched to eval mode so BatchNorm/Dropout use inference
        behaviour (and BatchNorm running statistics are not updated by the
        Grad-CAM forward passes), and every ``nn.ReLU`` has ``inplace``
        disabled.

        Security: the checkpoint is unpickled, which can execute arbitrary
        code -- only load checkpoints from trusted sources.

        Args:
            model_path: Path passed to ``torch.load``.
            print_net: If True, print the loaded model.

        Returns:
            The ``torch.nn.Module`` stored under ``checkpoint['model']``.
        """
        import inspect

        load_kwargs = {'map_location': self.device}
        # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle
        # the full nn.Module stored in the checkpoint. Older torch versions do
        # not accept the keyword at all, hence the signature check.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        checkpoint = torch.load(model_path, **load_kwargs)
        model = checkpoint['model']
        # Checkpoints saved mid-training keep the module in train mode.
        model.eval()
        self.set_inplace_False(model)
        if print_net:
            print(model)
        return model

    def set_inplace_False(self, module):
        """Set ``inplace = False`` on every ``torch.nn.ReLU`` in ``module``.

        Walks the module tree with ``module.modules()``, which also skips
        ``None`` placeholders that can appear in ``_modules``. Presumably done
        because in-place activations interfere with the gradient hooks that
        Grad-CAM relies on.
        """
        for submodule in module.modules():
            if isinstance(submodule, torch.nn.ReLU):
                submodule.inplace = False

    def generate_grad_cam(self, path):
        """Compute the Grad-CAM map for the image stored at ``path``.

        The target class is derived from ``path`` via ``target_from_path``.
        When ``self.overlay`` is set, ``overlay_numpy(img, attribution, path)``
        is called as well.

        Returns:
            A NumPy array (singleton dimensions squeezed) scaled so its maximum
            is 1.0. If the map has no positive value, an all-zero array of the
            same shape and dtype is returned instead of dividing by zero.
        """
        img = self.pil_loader(path, 3)
        input_image = self.transform_pil_to_tensor(img)
        tclass = self.target_from_path(path)
        grayscale_cams = self.gc_model2(input=input_image, class_idx=tclass)
        # NOTE(review): assumes element 0 of the GradCAM result is the map.
        cam = grayscale_cams[0].detach().cpu().numpy().squeeze()
        peak = cam.max()
        if peak > 0:
            # Out-of-place: ``cam`` may share memory with the CAM tensor.
            attribution = cam / peak
        else:
            attribution = np.zeros_like(cam)
        if self.overlay:
            overlay_numpy(img, attribution, path)
        return attribution

    def target_from_path(self, path):
        """Return the class index implied by the image's parent folder name.

        ``path`` is split on '/' (e.g. ``data/Mass/img.png`` -> ``'Mass'``)
        and the folder name is looked up in ``DISEASES``; an unknown folder
        raises ``ValueError`` from ``list.index``.

        Returns:
            A 0-d ``torch.tensor`` holding the index, on ``self.device``.
        """
        disease = path.split('/')[-2]
        # NOTE(review): 'No Finding' images fall back to index 0
        # ('Atelectasis'); confirm this target is intended.
        indx = DISEASES.index(disease) if disease != 'No Finding' else 0
        return torch.tensor(indx, device=self.device)

    def save_img(self, image, input_path):
        """Save ``image`` in grayscale next to ``input_path`` with a ``_gc`` suffix.

        ``dir/img.png`` becomes ``dir/img_gc.png``; ``os.path.splitext`` keeps
        extensions of any length (e.g. ``.jpeg``) intact.
        """
        root, ext = os.path.splitext(input_path)
        gc_filename = root + '_gc' + ext
        # NOTE(review): a float array in [0, 1] (as returned by
        # generate_grad_cam) converts to grey levels 0/1 only; scale it to
        # uint8 0..255 first if that is the intended input.
        image_PIL = Image.fromarray(image).convert('L')
        image_PIL.save(gc_filename)

    def transform_pil_to_tensor(self, pil_image):
        """Convert a PIL image to a normalized ``(1, C, H, W)`` tensor on ``self.device``.

        The shorter image side is resized to 224 and channels are normalized
        with the ImageNet mean/std below.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        tensor = transform(pil_image).to(self.device)
        return tensor.unsqueeze(0)

    def pil_loader(self, path, n_channels):
        """Open the image at ``path`` as grayscale (``n_channels=1``) or RGB (``3``).

        Raises:
            ValueError: if ``n_channels`` is neither 1 nor 3.
        """
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() loads the pixel data while the file is still open.
            if n_channels == 1:
                return img.convert('L')
            elif n_channels == 3:
                return img.convert('RGB')
            else:
                raise ValueError('Invalid value for parameter n_channels!')

def _is_pending_image(folder, fname, overlay, override_gc):
    """Return True if ``folder/fname`` is a source image still needing Grad-CAM output."""
    # Files whose names carry these markers are treated as outputs, not inputs.
    if '_gc' in fname or 'overlay' in fname:
        return False
    stem = os.path.splitext(fname)[0]
    if not override_gc and os.path.exists(folder + '/' + stem + '_gc.png'):
        return False
    if overlay and os.path.exists(folder + '/' + stem + '_overlay.png'):
        return False
    return True


def create_GC_from_folder(path, classifier='checkpoint', layer_name='features.norm5',
                          overlay=True, override_gc=True, skip_folders=10):
    """Run Grad-CAM over the images in each class sub-folder of ``path``.

    Sub-folders whose name contains 'No Finding' are ignored. A file is
    processed unless it is itself an output (name contains '_gc' or
    'overlay'), its ``<stem>_gc.png`` exists and ``override_gc`` is False, or
    ``overlay`` is True and its ``<stem>_overlay.png`` already exists.

    Args:
        path: Directory whose sub-folders are named after the target class.
        classifier: Checkpoint path passed to ``GradCamGenerator``.
        layer_name: Sub-module used for Grad-CAM.
        overlay: Forwarded to ``GradCamGenerator`` (write overlays).
        override_gc: Reprocess images even if ``<stem>_gc.png`` exists.
        skip_folders: Number of leading ``os.listdir`` entries to skip.
            NOTE(review): ``os.listdir`` order is arbitrary, so which folders
            are skipped is filesystem-dependent; the default of 10 keeps the
            previously hard-coded behaviour (apparently resuming a partial
            run) -- pass 0 to process everything.
    """
    generator = GradCamGenerator(classifier, layer_name, overlay=overlay)
    # Folder paths stay '/'-joined because target_from_path splits on '/'.
    base = path.rstrip('/')
    entries = [e for e in os.listdir(path) if 'No Finding' not in e]
    for entry in entries[skip_folders:]:
        folder = base + '/' + entry
        if not os.path.isdir(folder):
            continue  # stray files next to the class folders
        files = [folder + '/' + f for f in os.listdir(folder)
                 if _is_pending_image(folder, f, overlay, override_gc)]
        for image_path in tqdm(files):
            generator.generate_grad_cam(image_path)