import os
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from utils import *  # NOTE(review): presumably also provides `transforms` (torchvision) — confirm
from tqdm import tqdm
from gradcam import GradCAM, GradCAMpp
from overlay_image import overlay_numpy

DISEASES = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule',
    'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
    'Fibrosis', 'Pleural Thickening', 'Hernia',
]


class GradCamGenerator:
    """Compute Grad-CAM attribution maps for images with a checkpointed model.

    The target class of each image is taken from the name of the folder that
    contains it (see ``target_from_path``).
    """

    def __init__(self, model_path, layer, overlay=False):
        """
        Args:
            model_path: Path to a checkpoint whose ``'model'`` entry is a full
                ``nn.Module`` object (not a state dict).
            layer: Dotted submodule name, as accepted by ``get_submodule``,
                whose activations Grad-CAM is computed on.
            overlay: If True, ``generate_grad_cam`` also calls
                ``overlay_numpy`` to render the heatmap over the image.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.load_model(model_path)
        self.layer = layer
        self.overlay = overlay  # Overlay GC heatmaps with image
        self.layer_module = self.model.get_submodule(layer)
        self.gc_model2 = GradCAM(self.model, self.layer_module)

    def load_model(self, model_path, print_net=False):
        """Load the model stored under ``'model'`` in the checkpoint at *model_path*.

        The model is moved to ``self.device`` via ``map_location``, has its
        ``nn.ReLU`` modules made non-in-place, and is put in eval mode.
        """
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints. On PyTorch >= 2.6 it defaults to
        # weights_only=True, which may reject a checkpoint holding a whole
        # module; confirm for the torch version in use.
        checkpoint = torch.load(model_path, map_location=self.device)
        model = checkpoint['model']
        self.set_inplace_False(model)
        # Inference mode: BatchNorm must use its running statistics (a batch of
        # one image would otherwise give degenerate statistics and overwrite
        # the running buffers) and Dropout must be off. eval() does not
        # disable autograd, so Grad-CAM's backward pass still works.
        model.eval()
        if print_net:
            print(model)
        return model

    def set_inplace_False(self, module):
        """Disable in-place computation on every ``nn.ReLU`` inside *module*.

        In-place ReLUs can overwrite activations that Grad-CAM hooks hold on
        to. Only ``nn.ReLU`` modules are affected; functional calls such as
        ``F.relu(x, inplace=True)`` inside a model's ``forward`` are not.
        """
        # modules() walks the whole tree (including *module* itself) and skips
        # children registered as None, which a manual _modules recursion
        # would crash on.
        for submodule in module.modules():
            if isinstance(submodule, nn.ReLU):
                submodule.inplace = False

    def generate_grad_cam(self, path):
        """Compute the Grad-CAM map for the image at *path*.

        Returns:
            The attribution map as a NumPy array, divided by its maximum when
            that maximum is positive (so its peak is 1); otherwise returned
            unscaled instead of producing NaNs.
        """
        img = self.pil_loader(path, 3)
        input_image = self.transform_pil_to_tensor(img)
        tclass = self.target_from_path(path)
        grayscale_cams = self.gc_model2(input=input_image, class_idx=tclass)
        attribution = grayscale_cams[0].detach().cpu().numpy().squeeze()
        # A former `255 *` factor here was cancelled by the division below,
        # so it has been dropped; the normalized result is the same.
        peak = attribution.max()
        if peak > 0:
            attribution = attribution / peak
        if self.overlay:
            overlay_numpy(img, attribution, path)
        return attribution

    def target_from_path(self, path):
        """Return the class-index tensor for the image at *path*.

        The class is the name of the image's parent folder (``path`` split on
        '/'), looked up in ``DISEASES``. 'No Finding' is not in ``DISEASES``
        and is mapped to index 0. Any other unknown folder name raises
        ``ValueError`` from ``list.index``.
        """
        disease = path.split('/')[-2]
        # NOTE(review): index 0 is 'Atelectasis'; presumably a placeholder for
        # healthy images — confirm this is intended.
        indx = 0 if disease == 'No Finding' else DISEASES.index(disease)
        return torch.tensor(indx, device=self.device)

    def save_img(self, image, input_path):
        """Save *image* as greyscale next to *input_path* with a ``_gc`` suffix.

        For example ``dir/photo.png`` -> ``dir/photo_gc.png``.
        """
        # NOTE(review): Image.fromarray needs a dtype Pillow can handle (e.g.
        # uint8); the float map from generate_grad_cam may need scaling to
        # 0-255 and casting first — confirm.
        root, ext = os.path.splitext(input_path)
        gc_filename = root + '_gc' + ext
        image_PIL = Image.fromarray(image).convert('L')
        image_PIL.save(gc_filename)

    def transform_pil_to_tensor(self, pil_image):
        """Resize, tensorize and ImageNet-normalize *pil_image*.

        Returns:
            A tensor on ``self.device`` with a leading batch dimension of 1.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        transform = transforms.Compose([
            transforms.Resize(224),  # shorter side -> 224, aspect ratio kept
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        tensor = transform(pil_image).to(self.device)
        return tensor.unsqueeze(0)

    def pil_loader(self, path, n_channels):
        """Open the image at *path* as greyscale (1) or RGB (3).

        Raises:
            ValueError: if *n_channels* is neither 1 nor 3.
        """
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() forces the lazy decode while the file is still open.
            if n_channels == 1:
                return img.convert('L')
            elif n_channels == 3:
                return img.convert('RGB')
            else:
                raise ValueError('Invalid value for parameter n_channels!')


def _needs_grad_cam(folder, filename, overlay, override_gc):
    """Return True if *filename* in *folder* is a source image still to process."""
    if '_gc' in filename or 'overlay' in filename:
        return False  # generated output, not a source image
    if not os.path.isfile(f'{folder}/{filename}'):
        return False
    stem = os.path.splitext(filename)[0]
    gc_exists = os.path.exists(f'{folder}/{stem}_gc.png')
    overlay_exists = os.path.exists(f'{folder}/{stem}_overlay.png')
    # override_gc only covers the _gc file; an existing overlay always causes
    # a skip when overlays are being produced.
    return (override_gc or not gc_exists) and (not overlay or not overlay_exists)


def create_GC_from_folder(path, classifier='checkpoint', layer_name='features.norm5',
                          overlay=True, override_gc=True, start_index=10):
    """Run Grad-CAM over every image in the class sub-folders of *path*.

    Args:
        path: Dataset root whose sub-folders are named after classes in
            ``DISEASES``. Entries whose name contains 'No Finding' are skipped.
        classifier: Checkpoint path passed to ``GradCamGenerator``.
        layer_name: Target layer name passed to ``GradCamGenerator``.
        overlay: Forwarded to ``GradCamGenerator``; when True, images whose
            ``<stem>_overlay.png`` already exists are skipped.
        override_gc: When False, images whose ``<stem>_gc.png`` already exists
            are skipped.
        start_index: Number of leading class folders (in sorted name order) to
            skip. The default 10 keeps the previously hard-coded ``[10:]``.
    """
    GC = GradCamGenerator(classifier, layer_name, overlay=overlay)
    root = path.rstrip('/')
    # Join with '/' (not os.sep) because target_from_path splits on '/'.
    # Sorting makes the start_index skip deterministic; os.listdir order is
    # arbitrary.
    folds = [
        f'{root}/{name}'
        for name in sorted(os.listdir(path))
        if 'No Finding' not in name and os.path.isdir(f'{root}/{name}')
    ]
    for cf in folds[start_index:]:
        files = [
            f'{cf}/{f}'
            for f in sorted(os.listdir(cf))
            if _needs_grad_cam(cf, f, overlay, override_gc)
        ]
        for cfil in tqdm(files):
            GC.generate_grad_cam(cfil)