# s12/utils/gradcam.py
# NOTE: Hugging Face page residue (uploader name, commit hash, "raw history
# blame", file size) removed from the top of this file so it parses as Python.
from torch.nn import functional as F
import cv2
import torch
import matplotlib.pyplot as plt
import numpy as np
def denormalize(img):
    """Undo CIFAR-10 normalization and convert a CHW tensor to an HWC array.

    Args:
        img: normalized image tensor of shape (C, H, W); C must match the
            3 channel statistics below.

    Returns:
        np.ndarray of shape (H, W, C), float32, with per-channel mean/std undone.
    """
    # CIFAR-10 per-channel mean/std (the values used at training time).
    mean = np.array((0.49139968, 0.48215841, 0.44653091))
    std = np.array((0.24703223, 0.24348513, 0.26158784))
    img = img.cpu().numpy().astype(dtype=np.float32)
    # Broadcast the channel statistics over H and W instead of looping
    # channel-by-channel; cast back to float32 as the original loop did.
    img = (img * std[:, None, None] + mean[:, None, None]).astype(np.float32)
    return np.transpose(img, (1, 2, 0))
class GradCAM:
    """Extract activations and gradients from target layers to build Grad-CAM maps.

    Registers a forward hook (to capture activations) and a backward hook
    (to capture gradients w.r.t. those activations) on every module whose
    name from ``model.named_modules()`` appears in ``candidate_layers``
    (or on every module when ``candidate_layers`` is None).

    Typical use: ``forward(images)`` -> ``backward(ids)`` -> ``generate(layer)``,
    then ``remove_hook()`` to leave the model clean.

    Args:
        model: trained torch module to inspect.
        candidate_layers: list of module names to hook, or None for all.
    """

    def __init__(self, model, candidate_layers=None):
        def save_fmaps(key):
            def forward_hook(module, input, output):
                self.fmap_pool[key] = output.detach()
            return forward_hook

        def save_grads(key):
            def backward_hook(module, grad_in, grad_out):
                self.grad_pool[key] = grad_out[0].detach()
            return backward_hook

        self.device = next(model.parameters()).device
        self.model = model
        self.handlers = []  # hook handles, kept so remove_hook() can detach them
        self.fmap_pool = {}  # layer name -> forward activations
        self.grad_pool = {}  # layer name -> gradients w.r.t. activations
        self.candidate_layers = candidate_layers  # list of layer names, or None
        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                self.handlers.append(module.register_forward_hook(save_fmaps(name)))
                # NOTE(review): register_backward_hook is deprecated in favor of
                # register_full_backward_hook; kept as-is since the full-hook
                # semantics differ for some modules — confirm before migrating.
                self.handlers.append(module.register_backward_hook(save_grads(name)))

    def _encode_one_hot(self, ids):
        """Return a tensor shaped like the model output with 1.0 at ``ids``, 0 elsewhere."""
        one_hot = torch.zeros_like(self.nll).to(self.device)
        # (debug print of one_hot.shape removed)
        one_hot.scatter_(1, ids, 1.0)
        return one_hot

    def forward(self, image):
        """Run the model on ``image``; return class scores sorted descending per row."""
        self.image_shape = image.shape[2:]  # (H, W) — needed later to upsample the CAM
        self.nll = self.model(image)
        return self.nll.sort(dim=1, descending=True)  # (sorted values, indices)

    def backward(self, ids):
        """Class-specific backpropagation: backprop 1.0 from the classes in ``ids``."""
        one_hot = self._encode_one_hot(ids)
        self.model.zero_grad()
        # retain_graph so generate() can be called for several layers / classes.
        self.nll.backward(gradient=one_hot, retain_graph=True)

    def remove_hook(self):
        """Remove all registered forward/backward hook functions."""
        for handle in self.handlers:
            handle.remove()

    def _find(self, pool, target_layer):
        """Look up ``target_layer`` in ``pool``; raise ValueError if it was never hooked."""
        try:
            return pool[target_layer]
        except KeyError:
            raise ValueError("Invalid layer name: {}".format(target_layer)) from None

    def generate(self, target_layer):
        """Return the [0, 1]-scaled Grad-CAM map, shape (B, 1, H, W), for ``target_layer``."""
        fmaps = self._find(self.fmap_pool, target_layer)
        grads = self._find(self.grad_pool, target_layer)
        # Global-average-pool the gradients into per-channel importance weights.
        weights = F.adaptive_avg_pool2d(grads, 1)
        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
        gcam = F.relu(gcam)
        # Upsample to the input resolution captured during the forward pass.
        gcam = F.interpolate(
            gcam, self.image_shape, mode="bilinear", align_corners=False
        )
        # Min-max scale each map in the batch to [0, 1].
        B, C, H, W = gcam.shape
        gcam = gcam.view(B, -1)
        gcam -= gcam.min(dim=1, keepdim=True)[0]
        gcam /= gcam.max(dim=1, keepdim=True)[0]
        gcam = gcam.view(B, C, H, W)
        return gcam
def generate_gradcam(misclassified_images, model, target_layers, device):
    """Compute Grad-CAM maps for a batch of misclassified images.

    Args:
        misclassified_images: iterable of (image, predicted, correct) tuples.
        model: trained torch model to explain.
        target_layers: list of module names (from ``named_modules``) to map.
        device: device to run inference on.

    Returns:
        (layers, probs, ids): one CAM tensor per entry of ``target_layers``,
        plus the descending-sorted class scores and their indices from the
        forward pass.
    """
    images = [img for img, _pred, _correct in misclassified_images]
    labels = [correct for _img, _pred, correct in misclassified_images]

    model.eval()
    # Map input batch to the target device.
    images = torch.stack(images).to(device)

    gcam = GradCAM(model, target_layers)
    probs, ids = gcam.forward(images)

    # Backprop against the *correct* labels so the maps explain the true class.
    ids_ = torch.LongTensor(labels).view(len(images), -1).to(device)
    gcam.backward(ids=ids_)

    layers = []
    for target_layer in target_layers:
        print("Generating Grad-CAM @{}".format(target_layer))
        layers.append(gcam.generate(target_layer=target_layer))

    # Remove hooks so the model is left clean for further use.
    gcam.remove_hook()
    return layers, probs, ids
def plot_gradcam_images(gcam_layers, target_layers, classes, image_size,predicted, misclassified_images):
    """Render a grid of Grad-CAM overlays for misclassified images.

    Grid layout: column 1 carries row labels ("INPUT", then one row per
    target layer); each remaining column is one image.  Row 1 shows the
    actual/predicted class names, row 2 the denormalized input image, and
    rows 3+ the JET heat-map overlay for each target layer.

    Args:
        gcam_layers: CAM tensors indexed as gcam_layers[layer][image].
        target_layers: names of the hooked layers, used as row labels.
        classes: class-index -> class-name mapping.
        image_size: shape each stored image is viewed as before display.
        predicted: per-image predictions; predicted[j][0] is the class index.
        misclassified_images: iterable of (image, predicted, correct) tuples.
    """
    images=[]
    labels=[]
    # Collect the images and their *correct* labels.
    for i, (img, pred, correct) in enumerate(misclassified_images):
        images.append(img)
        labels.append(correct)
    c = len(images)+1        # columns: one label column + one per image
    r = len(target_layers)+2  # rows: text row + input row + one per layer
    fig = plt.figure(figsize=(60,30))
    fig.subplots_adjust(hspace=0.01, wspace=0.01)
    # Top-left cell: "INPUT" row label.
    ax = plt.subplot(r, c, 1)
    ax.text(0.3,-0.5, "INPUT", fontsize=28)
    plt.axis('off')
    for i in range(len(target_layers)):
        target_layer = target_layers[i]
        # First column of this layer's row: the layer name.
        ax = plt.subplot(r, c, c*(i+1)+1)
        ax.text(0.3,-0.5, target_layer, fontsize=28)
        plt.axis('off')
        for j in range(len(images)):
            img = np.uint8(255 * denormalize(images[j].view(image_size)))
            if i==0:
                # Row 1: actual vs predicted class names for image j.
                ax = plt.subplot(r, c, j+2)
                ax.text(0, 0.2, f"actual: {classes[labels[j]]} \npred: {classes[predicted[j][0]]}", fontsize=18)
                plt.axis('off')
                # Row 2: the raw (denormalized) input image.
                # NOTE(review): source indentation was lost; these three lines
                # are assumed to sit inside `if i==0` (drawing once) — drawing
                # them every layer iteration would produce the same figure.
                plt.subplot(r, c, c+j+2)
                plt.imshow(img)
                plt.axis('off')
            # Rows 3+: invert the CAM so hot regions map to warm JET colors.
            heatmap = 1-gcam_layers[i][j].cpu().numpy()[0] # reverse the color map
            heatmap = np.uint8(255 * heatmap)
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
            # Blend image and heat map 50/50, then upscale to 128x128 for display.
            superimposed_img = cv2.resize(cv2.addWeighted(img, 0.5, heatmap, 0.5, 0), (128,128))
            plt.subplot(r, c, (i+2)*c+j+2)
            plt.imshow(superimposed_img, interpolation='bilinear')
            plt.axis('off')
    plt.show()