Spaces:
Sleeping
Sleeping
from torch.nn import functional as F | |
import cv2 | |
import torch | |
import matplotlib.pyplot as plt | |
import numpy as np | |
def denormalize(
    img,
    mean=(0.49139968, 0.48215841, 0.44653091),
    std=(0.24703223, 0.24348513, 0.26158784),
):
    """Undo per-channel normalisation and return a channels-last array.

    Each channel ``i`` along the first axis of ``img`` is mapped to
    ``img[i] * std[i] + mean[i]``; the result is then transposed with
    axes ``(1, 2, 0)``, so a three-dimensional channels-first input is
    expected.

    Args:
        img: Tensor whose first axis indexes channels. It is detached and
            moved to the CPU before conversion, so tensors that require
            grad or live on another device are accepted.
        mean: Per-channel offsets added back. Defaults to the constants
            this module originally hard-coded.
        std: Per-channel scales multiplied back. Defaults to the constants
            this module originally hard-coded.

    Returns:
        A ``float32`` NumPy array with the channel axis moved last.

    Raises:
        IndexError: If ``img`` has more channels than ``mean`` or ``std``
            provide statistics for.
    """
    # astype() copies, so the caller's tensor storage is never modified.
    arr = img.detach().cpu().numpy().astype(np.float32)
    channels = arr.shape[0]
    if channels > len(mean) or channels > len(std):
        raise IndexError(
            "img has {} channels but only {} mean / {} std values were "
            "given".format(channels, len(mean), len(std))
        )
    # Only the first `channels` statistics are used, matching a per-channel
    # loop; the (C, 1, 1) shape broadcasts them over the spatial axes.
    scale = np.asarray(std[:channels], dtype=np.float32).reshape(-1, 1, 1)
    shift = np.asarray(mean[:channels], dtype=np.float32).reshape(-1, 1, 1)
    arr = arr * scale + shift
    return np.transpose(arr, (1, 2, 0))
class GradCAM:
    """Grad-CAM helper that records activations and gradients through hooks.

    On construction a forward hook and a backward hook are registered on
    every sub-module of ``model`` whose name (as yielded by
    ``model.named_modules()``) appears in ``candidate_layers``, or on every
    sub-module when ``candidate_layers`` is ``None``. The hooks store the
    detached module output and the detached gradient with respect to that
    output, keyed by module name.

    Typical sequence: ``forward(images)`` -> ``backward(ids)`` ->
    ``generate(layer_name)`` -> ``remove_hook()``.
    """

    def __init__(self, model, candidate_layers=None):
        """Attach the recording hooks to ``model``.

        Args:
            model: Module to inspect; it must have at least one parameter,
                because the device is read from the first one.
            candidate_layers: Iterable of module names to hook, or ``None``
                to hook every module.
        """
        self.device = next(model.parameters()).device
        self.model = model
        self.handlers = []  # hook handles, kept so they can be removed later
        self.fmap_pool = {}  # module name -> detached forward output
        self.grad_pool = {}  # module name -> detached grad w.r.t. that output
        self.candidate_layers = candidate_layers

        def save_fmaps(key):
            def forward_hook(module, inputs, output):
                self.fmap_pool[key] = output.detach()

            return forward_hook

        def save_grads(key):
            def backward_hook(module, grad_in, grad_out):
                self.grad_pool[key] = grad_out[0].detach()

            return backward_hook

        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                self.handlers.append(
                    module.register_forward_hook(save_fmaps(name))
                )
                # NOTE(review): the legacy register_backward_hook is kept on
                # purpose; only grad_out[0] is consumed here. Switching to
                # register_full_backward_hook should be validated against
                # models that use in-place ops before being adopted.
                self.handlers.append(
                    module.register_backward_hook(save_grads(name))
                )

    def _encode_one_hot(self, ids):
        """Return a tensor shaped like the last model output with 1.0 at ``ids``.

        ``ids`` is used as the index argument of ``scatter_`` along dim 1,
        so it must be an integer tensor with the same number of dimensions
        as the model output.
        """
        one_hot = torch.zeros_like(self.nll).to(self.device)
        one_hot.scatter_(1, ids, 1.0)
        return one_hot

    def forward(self, image):
        """Run the model and return its outputs sorted per row.

        Stores ``image.shape[2:]`` (assumed to be height x width of a
        batched channels-first input -- TODO confirm) for later upsampling,
        and keeps the raw model output in ``self.nll`` for ``backward``.

        Returns:
            The result of ``sort(dim=1, descending=True)`` on the model
            output: a ``(values, indices)`` pair.
        """
        self.image_shape = image.shape[2:]
        self.nll = self.model(image)
        return self.nll.sort(dim=1, descending=True)

    def backward(self, ids):
        """Back-propagate from the output entries selected by ``ids``.

        The one-hot encoding of ``ids`` is used as the upstream gradient, so
        only the selected entries contribute. The graph is retained so the
        call can be repeated for other ids after the same forward pass.
        """
        one_hot = self._encode_one_hot(ids)
        self.model.zero_grad()
        self.nll.backward(gradient=one_hot, retain_graph=True)

    def remove_hook(self):
        """Remove every forward/backward hook registered by this object."""
        for handle in self.handlers:
            handle.remove()
        # Forget the stale handles so a repeated call has nothing to do.
        self.handlers.clear()

    def _find(self, pool, target_layer):
        """Look up ``target_layer`` in ``pool``.

        Raises:
            ValueError: If nothing was recorded under that layer name.
        """
        try:
            return pool[target_layer]
        except KeyError:
            raise ValueError(
                "Invalid layer name: {}".format(target_layer)
            ) from None

    def generate(self, target_layer):
        """Build the Grad-CAM map for ``target_layer``.

        The recorded gradients are average-pooled to one weight per
        channel, used to weight the recorded feature maps, summed over
        channels, passed through ReLU, bilinearly resized to the spatial
        size captured in ``forward`` and finally min-max scaled per sample.

        Returns:
            Tensor of shape ``(batch, 1, H, W)`` with values in ``[0, 1]``.
            A sample whose map is constant yields all zeros rather than NaN.

        Raises:
            ValueError: If ``target_layer`` has no recorded activations or
                gradients (not hooked, or forward/backward not run yet).
        """
        fmaps = self._find(self.fmap_pool, target_layer)
        grads = self._find(self.grad_pool, target_layer)

        weights = F.adaptive_avg_pool2d(grads, 1)
        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
        gcam = F.relu(gcam)
        # The input size was captured during the forward pass.
        gcam = F.interpolate(
            gcam, self.image_shape, mode="bilinear", align_corners=False
        )

        # Scale each sample to [0, 1].
        batch, channels, height, width = gcam.shape
        flat = gcam.reshape(batch, -1)
        flat = flat - flat.min(dim=1, keepdim=True)[0]
        peak = flat.max(dim=1, keepdim=True)[0]
        # A constant map has peak == 0 after the shift; dividing by it would
        # produce NaN, so divide by 1 instead and keep the zeros.
        peak = torch.where(peak > 0, peak, torch.ones_like(peak))
        flat = flat / peak
        return flat.reshape(batch, channels, height, width)
def generate_gradcam(misclassified_images, model, target_layers, device):
    """Compute Grad-CAM maps for a batch of misclassified samples.

    Args:
        misclassified_images: Iterable of ``(image, predicted, correct)``
            triples. The images are stacked into one batch and the
            ``correct`` entries are the outputs the gradients are taken
            against; ``predicted`` is not used here.
        model: Network to inspect. It is switched to eval mode.
        target_layers: Sequence of module names to produce maps for.
        device: Device the batch and the label indices are moved to.

    Returns:
        ``(layers, probs, ids)`` where ``layers`` holds one Grad-CAM tensor
        per entry of ``target_layers`` (same order) and ``probs, ids`` are
        the sorted model outputs and their indices from ``GradCAM.forward``.
    """
    images = []
    labels = []
    for img, _pred, correct in misclassified_images:
        images.append(img)
        labels.append(correct)

    model.eval()
    # Map the input to the requested device.
    images = torch.stack(images).to(device)

    # Set up Grad-CAM; this registers hooks on the model.
    gcam = GradCAM(model, target_layers)
    try:
        # Forward pass.
        probs, ids = gcam.forward(images)
        # Outputs against which to compute gradients.
        ids_ = torch.LongTensor(labels).view(len(images), -1).to(device)
        # Backward pass.
        gcam.backward(ids=ids_)

        layers = []
        for target_layer in target_layers:
            print("Generating Grad-CAM @{}".format(target_layer))
            layers.append(gcam.generate(target_layer=target_layer))
    finally:
        # Always detach the hooks, even when a step above fails, so they do
        # not stay attached to the caller's model.
        gcam.remove_hook()
    return layers, probs, ids
def plot_gradcam_images(gcam_layers, target_layers, classes, image_size, predicted, misclassified_images):
    """Plot the input images and their Grad-CAM overlays on a grid.

    The grid has one column per image plus a leading label column. The
    first row holds the "actual / pred" captions, the second row the
    de-normalised inputs, and each following row the overlays for one
    entry of ``target_layers``.

    Args:
        gcam_layers: One Grad-CAM tensor per target layer, indexed as
            ``gcam_layers[layer][sample]``; the first element of each
            sample is used as the 2-D heat map, with values expected in
            ``[0, 1]``.
        target_layers: Layer names, used as row labels.
        classes: Sequence mapping class indices to display names.
        image_size: Shape each image is viewed as before de-normalisation.
        predicted: Per-sample predictions; ``predicted[j][0]`` is shown as
            the predicted class of sample ``j``.
        misclassified_images: Iterable of ``(image, predicted, correct)``
            triples; only ``image`` and ``correct`` are read.
    """
    images = []
    labels = []
    for img, _pred, correct in misclassified_images:
        images.append(img)
        labels.append(correct)

    cols = len(images) + 1
    rows = len(target_layers) + 2

    fig = plt.figure(figsize=(60, 30))
    fig.subplots_adjust(hspace=0.01, wspace=0.01)

    ax = plt.subplot(rows, cols, 1)
    ax.text(0.3, -0.5, "INPUT", fontsize=28)
    plt.axis('off')

    # The captions and the input row do not depend on the layer, so each
    # image is converted and drawn once here and reused for every overlay.
    rgb_images = []
    for j, image in enumerate(images):
        rgb = np.uint8(255 * denormalize(image.view(image_size)))
        rgb_images.append(rgb)

        ax = plt.subplot(rows, cols, j + 2)
        ax.text(0, 0.2, f"actual: {classes[labels[j]]} \npred: {classes[predicted[j][0]]}", fontsize=18)
        plt.axis('off')

        plt.subplot(rows, cols, cols + j + 2)
        plt.imshow(rgb)
        plt.axis('off')

    for i, target_layer in enumerate(target_layers):
        # Row label, placed in the first column.
        ax = plt.subplot(rows, cols, cols * (i + 1) + 1)
        ax.text(0.3, -0.5, target_layer, fontsize=28)
        plt.axis('off')

        for j, rgb in enumerate(rgb_images):
            # Reverse the colour map by inverting the heat values.
            # NOTE(review): presumably this compensates for the channel
            # order of cv2.applyColorMap versus what imshow expects --
            # confirm before changing.
            heatmap = 1 - gcam_layers[i][j].cpu().numpy()[0]
            heatmap = np.uint8(255 * heatmap)
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
            superimposed_img = cv2.resize(
                cv2.addWeighted(rgb, 0.5, heatmap, 0.5, 0), (128, 128)
            )

            plt.subplot(rows, cols, (i + 2) * cols + j + 2)
            plt.imshow(superimposed_img, interpolation='bilinear')
            plt.axis('off')

    plt.show()