# AttentionMaps / app.py
import os
import json
import warnings
from typing import Tuple, Dict
from timeit import default_timer as timer

import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms

warnings.filterwarnings('ignore')
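
# Collect the bundled example images for the Gradio galleries and load the
# ImageNet-1k class names.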
example_list = [["examples/" + example] for example in os.listdir("examples")]

with open('labels/imagenet1k-simple-labels.json') as f:
    class_names = json.load(f)
from model import VisionTransformer
from capture_weights import vit_weights

# ViT-B/16 pretrained on ImageNet-21k and fine-tuned on ImageNet-2012 at 384x384.
vision_transformer = VisionTransformer.from_name('ViT-B_16', num_classes=1000)
model_weights = torch.load('pretrained_weights/ViT-B_16_imagenet21k_imagenet2012.pth',
                           map_location=torch.device('cpu'))
vision_transformer.load_state_dict(model_weights)
vision_transformer.eval()  # disable dropout for deterministic inference
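
# Preprocessing: resize to the 384x384 fine-tuning resolution and normalize
# with the standard ImageNet mean and standard deviation.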
data_transforms = transforms.Compose([
    transforms.Resize(size=(384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])
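
# Helpers that undo the preprocessing so tensors can be rendered as images again.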
def inv_normalize(tensor):
    """Min-max rescale an image tensor into the 0-255 range for display."""
    tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) * (256 - 1e-5)
    return tensor
def inv_transform(tensor, normalize=True):
    """Convert a (C, H, W) tensor back to a PIL image."""
    if normalize:
        tensor = inv_normalize(tensor)
    array = tensor.detach().cpu().numpy()
    array = array.transpose(1, 2, 0).astype(np.uint8)
    return Image.fromarray(array)
def predict_image(image) -> Tuple[Dict, float]:
    """Return the top-3 predicted classes with probabilities for an input image."""
    input_tensor = data_transforms(image)
    start_time = timer()
    prediction_dict = {}
    with torch.inference_mode():
        # Add a batch dimension and unpack the single row of logits.
        [logits] = vision_transformer(input_tensor[None])
    probs = torch.softmax(logits, dim=0)
    topk_prob, topk_id = torch.topk(probs, 3)
    for i in range(topk_prob.size(0)):
        prediction_dict[class_names[topk_id[i].item()]] = topk_prob[i].item()
    prediction_time = round(timer() - start_time, 5)
    return prediction_dict, prediction_time
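
# Example usage (hypothetical filename, assuming it exists under examples/):
#   preds, secs = predict_image(Image.open("examples/dog.jpg"))

# Attention rollout (Abnar & Zuidema, 2020): average attention over heads, add an
# identity matrix to account for residual connections, renormalize, and multiply
# the per-layer matrices so each entry estimates how much an input patch flows
# into a token at the chosen layer.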
def get_attention_map(img, num_layer=5, get_mask=False):
    """Overlay (or return) the attention rollout mask up to the chosen Encoder layer."""
    x = data_transforms(img)
    logits, att_mat = vit_weights(x.unsqueeze(0))
    att_mat = torch.stack(att_mat).squeeze(1)
    # Average the attention weights across the 12 heads.
    att_mat = torch.mean(att_mat, dim=1)
    # To account for residual connections, add an identity matrix to the
    # attention matrix and re-normalize the weights.
    residual_att = torch.eye(att_mat.size(1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)
    # Recursively multiply the weight matrices, layer by layer.
    joint_attentions = torch.zeros(aug_att_mat.size())
    joint_attentions[0] = aug_att_mat[0]
    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])
    v = joint_attentions[num_layer]
    # 577 tokens = 1 [CLS] + a 24x24 patch grid (384x384 input, 16x16 patches).
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    # Attention from the [CLS] token to every patch token, reshaped to the grid.
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    if get_mask:
        attn_map = cv2.resize(mask / mask.max(), img.size)
    else:
        mask = cv2.resize(mask / mask.max(), img.size)[..., np.newaxis]
        attn_map = (mask * img).astype("uint8")
    return attn_map
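
# Example usage (hypothetical filename): the bare rollout mask for the final layer
#   mask = get_attention_map(Image.open("examples/dog.jpg"), num_layer=11, get_mask=True)

# Gradio UI. The dropdown uses type="index", so the 1-based choice "6" is passed
# to get_attention_map as the 0-based index 5.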
attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[gr.Image(type="pil", label="Image"),
            gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
                        label="Attention Layer", value="6", type="index"),
            gr.Checkbox(label="Show Mask?")],
    outputs=gr.Image(type="pil", label="Attention Map").style(height=400),
    examples=example_list,
    title="Attention Maps 🔍",
    description="The ViT Base architecture has 12 transformer Encoder layers (12 attention heads in each).",
    article="From the dropdown menu, select the Encoder layer (tick the checkbox to visualize only the mask).")
classification_interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
             gr.Number(label="Prediction time (secs)")],
    examples=example_list,
    title="Object Identification ✅",
    description="ImageNet object identification using a pretrained ViT Base (Patch Size: 16 | Image Size: 384) architecture.",
    article="Upload an image of your own or pick one from the example list [[ImageNet Classes](https://github.com/anishathalye/imagenet-simple-labels/blob/master/imagenet-simple-labels.json)].")
demo = gr.TabbedInterface([attention_interface, classification_interface],
                          ["Visualize Attention Maps", "Image Prediction"],
                          title="ImageNet 1K 📷")

if __name__ == "__main__":
    demo.launch()