# test/tools/gradcam.py
# Author: Tu Bui
# NOTE: the following repository-viewer metadata was scraped into this file and
# was not valid Python; preserved here as a comment:
#   first commit 6142a25 | raw | history | blame | 4.71 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
gradcam visualisation for each GAN class
@author: Tu Bui @surrey.ac.uk
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import inspect
import argparse
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
import torchvision
from torch.autograd import Function
import torch.nn.functional as F
def show_cam_on_image(img, cam, cmap='jet'):
    """Blend a Grad-CAM heatmap onto an image.

    Args:
        img: PIL image of shape (H, W, 3).
        cam: heatmap array of shape (H, W) with values in [0, 1].
        cmap: matplotlib colormap name used to colorize the heatmap.
    Returns:
        PIL image with the colorized heatmap overlaid.
    """
    colormap = plt.get_cmap(cmap)
    heat = colormap(cam)[..., :3]  # drop alpha channel -> RGB in [0, 1]
    blended = np.array(img, dtype=np.float32) / 255. + heat
    blended /= blended.max()  # renormalize so the brightest pixel is 1
    return Image.fromarray(np.uint8(blended * 255))
class HookedModel(object):
    """Wrap a model so forwarding goes through `feedforward`, which registers
    gradient/feature hooks on the layer named by `feature_layer_name`.
    """

    def __init__(self, model, feature_layer_name):
        # e.g. 'layer4.2' -> ['layer4', '2']: a path of child-module names.
        self.model = model
        self.feature_trees = feature_layer_name.split('.')

    def __call__(self, x):
        return feedforward(x, self.model, self.feature_trees)
def feedforward(x, module, layer_names):
    """Forward `x` through `module`'s direct children, hooking one target layer.

    Walks `module._modules` in registration order. When the child whose name
    matches `layer_names[0]` is reached: if it is the leaf of the path, its
    output gets a backward hook (`save_gradients`) and its feature map is
    stashed (`save_features`) into the module-level `basket`; otherwise we
    recurse into that child with the remaining path.

    Args:
        x: input tensor.
        module: torch.nn.Module whose children are executed sequentially.
        layer_names: list of child names forming a path to the target layer,
            e.g. ['layer4', '2'].
    Returns:
        Output tensor after all children of `module` have been applied.

    NOTE(review): executing `_modules` sequentially assumes the network is a
    plain chain of children (true for torchvision ResNet); models with
    non-sequential forward logic would not be reproduced faithfully.
    """
    for name, submodule in module._modules.items():
        # print(f'Forwarding {name} ...')
        if name == layer_names[0]:
            if len(layer_names) == 1: # leaf node reached
                # print(f' Hook {name}')
                x = submodule(x)
                # capture d(criteria)/d(activation) during backward()
                x.register_hook(save_gradients)
                save_features(x)
            else:
                # print(f' Stepping into {name}:')
                x = feedforward(x, submodule, layer_names[1:])
        else:
            x = submodule(x)
            if name == 'avgpool': # specific for resnet50
                # flatten (N, C, 1, 1) -> (N, C) before the fc layer
                x = x.view(x.size(0), -1)
    return x
# Global scratchpad holding the gradients and feature maps captured by the
# hooks registered in `feedforward` for the layer(s) of interest.
basket = dict(grads=[], feature_maps=[])


def empty_basket():
    """Reset the shared `basket` before a new Grad-CAM pass.

    The lists are cleared in place so every module-level reference to
    `basket` keeps observing the same object. (The original implementation
    assigned a *local* variable named `basket`, which was a no-op: gradients
    and feature maps accumulated across calls.)
    """
    basket['grads'].clear()
    basket['feature_maps'].clear()
def save_gradients(grad):
    """Backward-hook callback: stash the incoming gradient in the shared basket."""
    basket['grads'] += [grad]
def save_features(feat):
    """Stash a hooked layer's forward feature map in the shared basket."""
    basket['feature_maps'] += [feat]
class GradCam(object):
    """Grad-CAM for a model whose gradient criterion is the cosine similarity
    between the model output and a target vector.

    Args:
        model: torch.nn.Module to explain (set to eval mode).
        feature_layer_name: dotted path to the layer whose activations are
            visualized, e.g. 'layer4.2'.
        use_cuda: move model and inputs to GPU when True.
    """

    def __init__(self, model, feature_layer_name, use_cuda=True):
        self.model = model
        self.hooked_model = HookedModel(model, feature_layer_name)
        self.cuda = use_cuda
        if self.cuda:
            # nn.Module.cuda() moves parameters in place and returns the same
            # object, so hooked_model keeps pointing at the right model.
            self.model = model.cuda()
        self.model.eval()

    def __call__(self, x, target, act=None):
        """Compute a Grad-CAM heatmap.

        Args:
            x: input tensor of shape (1, C, H, W).
            target: target vector compared against the model output
                (e.g. a one-hot class vector of shape (1, nclass)).
            act: optional activation applied to the model output before the
                cosine-similarity criterion.
        Returns:
            numpy array of shape (H, W) in [0, 1], averaged over all hooked
            layers.
        """
        empty_basket()
        target = torch.as_tensor(target, dtype=torch.float)
        if self.cuda:
            x = x.cuda()
            target = target.cuda()
        z = self.hooked_model(x)
        if act is not None:
            z = act(z)
        criteria = F.cosine_similarity(z, target)
        self.model.zero_grad()
        criteria.backward(retain_graph=True)
        # gradients appear in reverse hook-registration order
        gradients = [grad.cpu().data.numpy() for grad in basket['grads'][::-1]]
        feature_maps = [feat.cpu().data.numpy() for feat in basket['feature_maps']]
        height, width = int(x.shape[2]), int(x.shape[3])
        cams = []
        for feat, grad in zip(feature_maps, gradients):
            # feat and grad have shape (1, C, H', W')
            weight = np.mean(grad, axis=(2, 3), keepdims=True)[0]  # (C,1,1)
            cam = np.sum(weight * feat[0], axis=0)  # (H',W')
            # BUGFIX: cv2.resize takes dsize=(width, height); the original
            # passed x.shape[2:] == (H, W), swapping axes for non-square input.
            cam = cv2.resize(cam, (width, height))
            cam = cam - np.min(cam)
            cam = cam / (np.max(cam) + np.finfo(np.float32).eps)
            cams.append(cam)
        cams = np.array(cams).mean(axis=0)  # (H,W)
        return cams
def gradcam_demo(image_path='/mnt/fast/nobackup/users/tb0035/projects/diffsteg/ControlNet/examples/catdog.jpg',
                 class_id=285, output_path='test.jpg'):
    """Run Grad-CAM on one image with an ImageNet-pretrained ResNet-50.

    Args:
        image_path: path to the input image (default kept for backward
            compatibility with the original hard-coded demo path).
        class_id: ImageNet class index to explain (285 = Egyptian cat).
        output_path: where the heatmap-overlaid image is saved.
    """
    from torchvision import transforms
    model = torchvision.models.resnet50(pretrained=True)
    model.eval()
    gradcam = GradCam(model, 'layer4.2', True)
    resize = transforms.Resize((224, 224))
    preprocess = transforms.Compose([
        resize,
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    im0 = Image.open(image_path).convert('RGB')
    im = preprocess(im0).unsqueeze(0)
    # build the one-hot target via the shared helper instead of by hand
    target = make_target_vector(1000, class_id)
    cam = gradcam(im, target)
    im0 = resize(im0)  # resize original to match the heatmap resolution
    out = show_cam_on_image(im0, cam)
    out.save(output_path)
    print('done')
def make_target_vector(nclass, target_class_id):
    """Build a (1, nclass) float32 one-hot row vector.

    Args:
        nclass: total number of classes.
        target_class_id: index set to 1; all other entries are 0.
    Returns:
        numpy array of shape (1, nclass), dtype float32.
    """
    vec = np.zeros(nclass, dtype=np.float32)
    vec[target_class_id] = 1
    return vec[None, :]
if __name__ == '__main__':
    # Script entry point: run the ResNet-50 Grad-CAM demo.
    gradcam_demo()