#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Grad-CAM visualisation for each GAN class
@author: Tu Bui @surrey.ac.uk
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import inspect
import argparse
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import cv2
from PIL import Image, ImageDraw, ImageFont
import torchvision
from torch.autograd import Function
import torch.nn.functional as F


def show_cam_on_image(img, cam, cmap='jet'):
    """
    Args:
    img     PIL image (H,W,3)
    cam     heatmap (H, W), range [0,1]
    Returns:
            PIL image with heatmap applied.
    """
    cm = plt.get_cmap(cmap)
    cam = cm(cam)[...,:3]  # RGB [0,1]
    cam = np.array(img, dtype=np.float32)/255. + cam 
    cam /= cam.max()
    cam = np.uint8(cam*255)
    return Image.fromarray(cam)


class HookedModel(object):
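    """Wrap a model so a forward pass hooks the layer named in feature_layer_name.

    feature_layer_name uses dot notation (e.g. 'layer4.2' for resnet50); the
    hooked layer's output and its gradient are stashed in the global basket.
    """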
    def __init__(self, model, feature_layer_name):
        self.model = model 
        self.feature_trees = feature_layer_name.split('.')

    def __call__(self, x):
        x = feedforward(x, self.model, self.feature_trees)
        return x 


def feedforward(x, module, layer_names):
    for name, submodule in module._modules.items():
        # print(f'Forwarding {name} ...')
        if name == layer_names[0]:
            if len(layer_names) == 1:  # leaf node reached
                # print(f'    Hook {name}')
                x = submodule(x)
                x.register_hook(save_gradients)
                save_features(x)
            else:
                # print(f'  Stepping into {name}:')
                x = feedforward(x, submodule, layer_names[1:])
        else:
            x = submodule(x)
            if name == 'avgpool':  # specific for resnet50
                x = x.view(x.size(0), -1)
    return x


basket = dict(grads=[], feature_maps=[])  # global variable to hold the gradients and output features of the layers of interest

def empty_basket():
    # clear the shared basket in place (rebinding a local name would leave the
    # module-level basket untouched)
    basket['grads'].clear()
    basket['feature_maps'].clear()

def save_gradients(grad):
    basket['grads'].append(grad)

def save_features(feat):
    basket['feature_maps'].append(feat)


class GradCam(object):
    def __init__(self, model, feature_layer_name, use_cuda=True):
        self.model = model 
        self.hooked_model = HookedModel(model, feature_layer_name)
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()
        self.model.eval()

    def __call__(self, x, target, act=None):
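        """Compute a Grad-CAM heatmap for the given target.

        Args:
            x       input tensor of shape (1, 3, H, W)
            target  array of shape (1, nclass), e.g. a one-hot class vector
            act     optional activation applied to the model output before the
                    cosine-similarity objective is computed
        Returns:
            numpy heatmap of shape (H, W), normalised to [0, 1]
        """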
        empty_basket()
        target = torch.as_tensor(target, dtype=torch.float)
        if self.cuda:
            x = x.cuda()
            target = target.cuda()
        z = self.hooked_model(x)
        if act is not None:
            z = act(z)
        criteria = F.cosine_similarity(z, target)
        self.model.zero_grad()
        criteria.backward(retain_graph=True)
        gradients = [grad.cpu().data.numpy() for grad in basket['grads'][::-1]]  # gradients appear in reversed order
        feature_maps = [feat.cpu().data.numpy() for feat in basket['feature_maps']]
        cams = []
        for feat, grad in zip(feature_maps, gradients):
            # feat and grad have shape (1, C, H, W)
            weight = np.mean(grad, axis=(2,3), keepdims=True)[0]  # (C,1,1)
            cam = np.sum(weight * feat[0], axis=0)  # (H, W)
            cam = cv2.resize(cam, (x.shape[3], x.shape[2]))  # cv2 expects (width, height)
            cam = cam - np.min(cam)
            cam = cam / (np.max(cam) + np.finfo(np.float32).eps)
            cams.append(cam)
        cams = np.array(cams).mean(axis=0)  # (H,W)
        return cams


def gradcam_demo():
    from torchvision import transforms
    model = torchvision.models.resnet50(pretrained=True)
    model.eval()
    gradcam = GradCam(model, 'layer4.2', True)
    tform = [
                transforms.Resize((224, 224)),
                # transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ]
    preprocess = transforms.Compose(tform)
    im0 = Image.open('/mnt/fast/nobackup/users/tb0035/projects/diffsteg/ControlNet/examples/catdog.jpg').convert('RGB')
    im = preprocess(im0).unsqueeze(0)
    target = np.zeros((1,1000), dtype=np.float32)
    target[0, 285] = 1  # cat
    cam = gradcam(im, target)

    im0 = tform[0](im0)
    out = show_cam_on_image(im0, cam)
    out.save('test.jpg')
    print('done')


def make_target_vector(nclass, target_class_id):
    out = np.zeros((1, nclass), dtype=np.float32)
    out[0, target_class_id] = 1
    return out 


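# A minimal usage sketch (not part of the original pipeline) showing how
# make_target_vector and GradCam could be combined for an arbitrary classifier.
# The function name, default feature layer, output path and preprocessing are
# assumptions; adapt them to the model being inspected.
def gradcam_for_class(model, image_path, class_id, nclass,
                      feature_layer='layer4.2', out_path='cam.jpg'):
    from torchvision import transforms
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    im0 = Image.open(image_path).convert('RGB').resize((224, 224))
    im = preprocess(im0).unsqueeze(0)
    gradcam = GradCam(model, feature_layer, use_cuda=torch.cuda.is_available())
    cam = gradcam(im, make_target_vector(nclass, class_id))  # (H, W) in [0, 1]
    show_cam_on_image(im0, cam).save(out_path)
    return cam
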

if __name__ == '__main__':
    gradcam_demo()