import os
import cv2
import json
import torch
import numpy as np
import gradio as gr
from PIL import Image
from typing import Tuple, Dict
from timeit import default_timer as timer
from torchvision import transforms

import warnings
warnings.filterwarnings('ignore')

example_list = [["examples/" + example] for example in os.listdir("examples")]

with open('labels/imagenet1k-simple-labels.json') as f:
    class_names = json.load(f)

# `vit_weights` performs a forward pass that also returns each Encoder layer's
# attention matrices alongside the logits (used by `get_attention_map` below).
from model import VisionTransformer
from capture_weights import vit_weights

vision_transformer = VisionTransformer.from_name('ViT-B_16', num_classes=1000)
model_weights = torch.load('pretrained_weights/ViT-B_16_imagenet21k_imagenet2012.pth',
                           map_location=torch.device('cpu'))
vision_transformer.load_state_dict(model_weights)
vision_transformer.eval()  # disable dropout for deterministic inference

# Resize to the model's 384x384 input and apply the standard ImageNet
# normalization statistics.
data_transforms = transforms.Compose([
    transforms.Resize(size=(384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])

def inv_normalize(tensor):
    """Normalize an image tensor back to the 0-255 range."""
    tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) * (256 - 1e-5)
    return tensor

def inv_transform(tensor, normalize=True):
    """Convert a (C, H, W) tensor back to a PIL image."""
    if normalize:
        tensor = inv_normalize(tensor)
    array = tensor.detach().cpu().numpy()
    array = array.transpose(1, 2, 0).astype(np.uint8)
    return Image.fromarray(array)

def predict_image(image) -> Tuple[Dict, float]:
    """Return prediction classes with probabilities for an input image."""
    input_tensor = data_transforms(image)
    start_time = timer()
    prediction_dict = {}
    with torch.inference_mode():
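        # `input_tensor[None]` adds a batch dimension; the single-item batch of
        # logits is unpacked by the list destructuring below.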
        [logits] = vision_transformer(input_tensor[None])
        probs = torch.softmax(logits, dim=0)
        topk_prob, topk_id = torch.topk(probs, 3)
        for i in range(topk_prob.size(0)):
            prediction_dict[class_names[topk_id[i]]] = topk_prob[i].item()
    prediction_time = round(timer() - start_time, 5)
    return prediction_dict, prediction_time
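
# Usage sketch (illustrative path):
#   preds, secs = predict_image(Image.open("examples/sample.jpg").convert("RGB"))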

def get_attention_map(img, num_layer=5, get_mask=False):
    """Compute an attention-rollout map for `img` up to Encoder layer `num_layer`."""
    x = data_transforms(img)
    logits, att_mat = vit_weights(x.unsqueeze(0))

    # Stack the per-layer attention tensors into (layers, heads, tokens, tokens),
    # dropping the singleton batch dimension.
    att_mat = torch.stack(att_mat).squeeze(1)
    # Average the attention weights across the 12 heads
    att_mat = torch.mean(att_mat, dim=1)

    # To account for residual connections, we add an identity matrix to the
    # attention matrix and re-normalize the weights.
    residual_att = torch.eye(att_mat.size(1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

    # Recursively multiply the weight matrices
    joint_attentions = torch.zeros(aug_att_mat.size())
    joint_attentions[0] = aug_att_mat[0]

    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])

    v = joint_attentions[num_layer]
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    # Row 0 is the [CLS] token; its attention to the patch tokens (columns 1:)
    # yields one weight per image patch, reshaped to the patch grid.
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    if get_mask:
        attn_map = cv2.resize(mask / mask.max(), img.size)
    else:
        # Scale the input image by the upsampled, normalized mask to get an overlay.
        mask = cv2.resize(mask / mask.max(), img.size)[..., np.newaxis]
        attn_map = (mask * img).astype("uint8")
    return attn_map
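
# Standalone usage sketch (illustrative path; `num_layer` is a 0-based index,
# so 5 corresponds to Encoder layer 6):
#   img = Image.open("examples/sample.jpg").convert("RGB")
#   overlay = get_attention_map(img, num_layer=5)
#   mask_only = get_attention_map(img, num_layer=5, get_mask=True)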

attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[gr.Image(type="pil", label="Image"),
            gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
                        label="Attention Layer", value="6", type="index"),
            gr.Checkbox(label="Show Mask?")],
    outputs=gr.Image(type="pil", label="Attention Map").style(height=400),
    examples=example_list,
    title="Attention Maps πŸ”",
    description="The ViT Base architecture has 12 transformer Encoder layers (12 attention heads in each).",
    article="From the dropdown menu, select the Encoder layer (tick the checkbox to visualize only the mask)."
)

classification_interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
             gr.Number(label="Prediction time (secs)")],
    examples=example_list,
    title="Object Identification βœ…",
    description="ImageNet object identification using pretrained ViT Base (Patch Size: 16 | Image Size: 384) architecture.",
    article="Upload an image from the example list or choose one of your own [[ImageNet Classes](https://github.com/anishathalye/imagenet-simple-labels/blob/master/imagenet-simple-labels.json)]."
)

demo = gr.TabbedInterface([attention_interface, classification_interface],
                          ["Visualize Attention Maps", "Image Prediction"], title="ImageNet 1K πŸ“·")

if __name__ == "__main__":
    demo.launch()