import torchvision
import gradio as gr
import torch
import torch.nn.functional as F
import einops
import matplotlib.cm as cm
import numpy as np


def colorize(tensor, cmap_fn=cm.turbo):
    # Map values in [0, 1] to RGB uint8 via a 256-entry colormap lookup
    # table, indexed per pixel with F.embedding.
    colors = cmap_fn(np.linspace(0, 1, 256))[:, :3]
    colors = torch.from_numpy(colors).to(tensor)
    tensor = tensor.squeeze(1) if tensor.ndim == 4 else tensor
    ids = (tensor * 256).clamp(0, 255).long()
    tensor = F.embedding(ids, colors).permute(0, 3, 1, 2)
    tensor = tensor.mul(255).clamp(0, 255).byte()
    return tensor
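
# A minimal sanity check for colorize (hypothetical usage, not part of the app):
#   ramp = torch.linspace(0, 1, 224).expand(1, 1, 224, 224)
#   rgb = colorize(ramp)  # -> uint8 tensor of shape (1, 3, 224, 224)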


# classes.txt is assumed to hold one ImageNet-1k class per line
# (e.g. "tabby, tabby cat"); keep only the first, lowercased synonym.
with open("classes.txt") as f:
    id2label = f.read().splitlines()
    id2label = [c.split(",")[0].lower() for c in id2label]
    label2id = {c: i for i, c in enumerate(id2label)}


# ImageNet-pretrained ResNet-50 from torchvision ("DEFAULT" resolves to
# the best available ImageNet-1k weights), kept in eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet50(weights="DEFAULT")
model.eval()
model.to(device)

# Caches for feature maps and their gradients, keyed by layer name.
fmap_pool = dict()
grad_pool = dict()


def forward_hook(name):
    def _hook(module, input, output):
        # Cache the forward activations, and register a hook on the output
        # tensor so the matching gradients are cached during the backward
        # pass. (Module.register_backward_hook is deprecated and unreliable
        # on container modules such as layer1-layer4, so the gradient is
        # captured on the output tensor instead.)
        fmap_pool[name] = output.detach()

        def _save_grad(grad):
            grad_pool[name] = grad.detach()

        if output.requires_grad:
            output.register_hook(_save_grad)

    return _hook


layer_choices = []
for n, m in model.named_children():
    layer_choices.append(n)
    m.register_forward_hook(forward_hook(n))
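
# For torchvision's resnet50, named_children() yields conv1, bn1, relu,
# maxpool, layer1..layer4, avgpool and fc; "layer4" (the dropdown default
# below) holds the coarse 7x7 feature maps Grad-CAM is typically shown on.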

preprocess = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)
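
# (The mean/std above are the standard ImageNet statistics that
# torchvision's pretrained classifiers expect.)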


def predict(image):
    if image is None:
        return None, None
    # Plain classification pass; no gradients are needed here, so run under
    # no_grad (the forward hooks then only cache activations).
    image = preprocess(image)[None].to(device)
    with torch.no_grad():
        probs = model(image).softmax(dim=1)
    result = {c: float(p) for c, p in zip(id2label, probs[0])}
    return result, None


def gradcam(image_orig, layer, event: gr.SelectData):
    # forward & backward: back-propagate a one-hot vector for the selected
    # class so the hooks capture d(score)/d(feature maps)
    model.zero_grad(set_to_none=True)
    target_class = torch.tensor([label2id[event.value]], device=device)
    gradient = F.one_hot(target_class, num_classes=len(label2id)).float()
    image = preprocess(image_orig)[None].to(device)
    model(image).backward(gradient=gradient)

    # Grad-CAM: alpha_k = GAP(grads_k), L^c = ReLU(sum_k alpha_k * fmaps_k)
    fmaps = fmap_pool[layer]
    grads = grad_pool[layer]
    weights = F.adaptive_avg_pool2d(grads, 1)
    gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
    gcam = F.relu(gcam)

    # post-process: upsample to the input resolution and min-max normalize
    # to [0, 1] (guarding against an all-zero map) before colorizing
    gcam = F.interpolate(
        gcam, size=image_orig.shape[:2], mode="bicubic", antialias=True
    )
    gcam -= einops.reduce(gcam, "b c h w -> b () () ()", "min")
    gcam /= einops.reduce(gcam, "b c h w -> b () () ()", "max").clamp(min=1e-8)
    gcam = colorize(gcam)[0].permute(1, 2, 0).cpu().numpy()
    return gcam


with gr.Blocks(title="Grad-CAM") as demo:
    gr.Markdown(
        """
# Grad-CAM
Unofficial re-implementation of Grad-CAM (https://arxiv.org/abs/1610.02391).<br>
Upload an image and select a prediction to show the Grad-CAM heatmap.
"""
    )

    with gr.Row():
        with gr.Column():
            layer = gr.Dropdown(layer_choices, label="ResNet-50", value="layer4")
            image = gr.Image(label="input", type="numpy")
            label = gr.Label(num_top_classes=10, label="top-10 predictions")
            exmpl = gr.Examples(["cat_dog.png"], image)

        with gr.Column():
            img_out = gr.Image(type="numpy", label="result")

        image.change(predict, inputs=[image], outputs=[label, img_out])
        label.select(gradcam, inputs=[image, layer], outputs=[img_out])


demo.launch()
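
# To try it locally (assuming the usual Hugging Face Spaces layout, with this
# file saved as app.py next to classes.txt and cat_dog.png):
#   pip install torch torchvision gradio einops matplotlib
#   python app.py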