Spaces:
Sleeping
Sleeping
import os | |
import PIL | |
import ast | |
import cv2 | |
import json | |
import torch | |
import pickle | |
import torchvision | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
from typing import Tuple, Dict | |
import matplotlib.pyplot as plt | |
from timeit import default_timer as timer | |
from torchvision import datasets, transforms | |
import warnings | |
# Silence every warning category so library deprecation noise stays out of
# the application logs.
warnings.filterwarnings('ignore')

# One single-element row per file found in ./examples. The listing is sorted
# because os.listdir returns entries in arbitrary order, which would
# otherwise shuffle the example gallery between runs/platforms.
example_list = [["examples/" + example] for example in sorted(os.listdir("examples"))]

# Human-readable class labels, loaded once at import time. The encoding is
# given explicitly so the JSON file is not decoded with a locale-dependent
# default.
with open('labels/imagenet1k-simple-labels.json', encoding='utf-8') as f:
    class_names = json.load(f)
from model import VisionTransformer | |
from capture_weights import vit_weights | |
# Build the ViT-B/16 classifier with a 1000-way head and load the pretrained
# checkpoint onto the CPU.
vision_transformer = VisionTransformer.from_name('ViT-B_16', num_classes=1000)
model_weights = torch.load('pretrained_weights/ViT-B_16_imagenet21k_imagenet2012.pth',
                           map_location=torch.device('cpu'))
vision_transformer.load_state_dict(model_weights)
# Modules are created in training mode; switch to evaluation mode so that
# train-only layers (e.g. dropout) are disabled and predictions are
# deterministic. torch.inference_mode() alone does not do this.
vision_transformer.eval()
# Preprocessing shared by the classifier and the attention visualiser:
# resize to 384x384, convert to a tensor, then normalise each channel.
# NOTE(review): these are the usual ImageNet mean/std values; some ViT
# checkpoints are trained with mean = std = 0.5 instead -- confirm that this
# matches how the loaded weights were trained.
data_transforms = transforms.Compose(
    [
        transforms.Resize(size=(384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def inv_normalize(tensor):
    """Min-max rescale a tensor into the half-open range [0, 256).

    The smallest element maps to 0 and the largest to just under 256, so a
    later truncating cast to ``uint8`` yields values in 0..255.

    A constant tensor (max == min) has no range to stretch; it is mapped to
    all zeros instead of dividing by zero and producing NaNs.
    """
    low = tensor.min()
    shifted = tensor - low
    span = tensor.max() - low
    # 256 - 1e-5 keeps the maximum strictly below 256 so it truncates to 255.
    scale = 256 - 1e-5
    if span == 0:
        # Degenerate case: ``shifted`` is already all zeros; multiplying keeps
        # the result dtype consistent with the regular path.
        return shifted * scale
    return shifted / span * scale
def inv_transform(tensor, normalize=True):
    """Convert a 3-D tensor back to a PIL image.

    Args:
        tensor: Tensor whose first axis is moved last before conversion
            (presumably channels-first image data, i.e. (C, H, W)).
        normalize: When true (the default), min-max rescale the values with
            ``inv_normalize`` first. When false, the values are cast to
            ``uint8`` as they are, so the caller must already supply data in
            the 0..255 range.

    Returns:
        The ``PIL.Image`` built from the resulting ``uint8`` array.
    """
    # Previously the flag was accepted but ignored; honour it.
    if normalize:
        tensor = inv_normalize(tensor)
    array = tensor.detach().cpu().numpy()
    # Move axis 0 to the end (channels-last) and truncate to 8-bit.
    array = array.transpose(1, 2, 0).astype(np.uint8)
    return PIL.Image.fromarray(array)
def predict_image(image) -> Tuple[Dict, float]:
    """Classify an image and report how long the forward pass took.

    Args:
        image: Input image accepted by ``data_transforms``.

    Returns:
        A pair ``(predictions, seconds)``: ``predictions`` maps the names of
        the three most probable classes to their softmax probabilities, and
        ``seconds`` is the elapsed time rounded to five decimals. The timer
        starts after preprocessing, so only inference and top-k selection
        are measured.
    """
    model_input = data_transforms(image)
    started = timer()
    with torch.inference_mode():
        # The model receives a batch of one; unpack its single row of logits.
        [logits] = vision_transformer(model_input.unsqueeze(0))
        probabilities = torch.softmax(logits, dim=0)
        best_probs, best_ids = torch.topk(probabilities, 3)
    predictions = {
        class_names[class_id]: probability
        for class_id, probability in zip(best_ids.tolist(), best_probs.tolist())
    }
    elapsed = round(timer() - started, 5)
    return predictions, elapsed
def get_attention_map(img, num_layer=5, get_mask=False):
    """Compute an attention-rollout map for one encoder layer of the ViT.

    Args:
        img: PIL image; its ``size`` is used as the output resolution.
        num_layer: Index into the per-layer joint attentions (0-based; the
            UI dropdown passes the index of the selected entry).
        get_mask: If true, return only the normalised mask; otherwise return
            the input image multiplied by the mask.

    Returns:
        With ``get_mask`` set, a 2-D float array whose maximum is 1, resized
        to the image size. Otherwise a ``uint8`` array holding the image
        scaled pixel-wise by that mask.

    Raises:
        IndexError: If ``num_layer`` is outside the available layers.
    """
    x = data_transforms(img)
    # This is a pure visualisation: run without autograd so no graph is
    # built (and kept alive) for every request.
    with torch.no_grad():
        _, att_mat = vit_weights(x.unsqueeze(0))
        # Stack the per-layer attention tensors and drop the size-1 batch
        # axis (presumably leaving layers x heads x tokens x tokens).
        att_mat = torch.stack(att_mat).squeeze(1)
        # Take the mean of the attention weights across the heads.
        att_mat = torch.mean(att_mat, dim=1)
        # To account for residual connections, add an identity matrix to the
        # attention matrix and re-normalise each row.
        residual_att = torch.eye(att_mat.size(1))
        aug_att_mat = att_mat + residual_att
        aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)
        # Rollout: each layer's matrix is multiplied onto the product of all
        # layers before it.
        joint_attentions = [aug_att_mat[0]]
        for layer_att in aug_att_mat[1:]:
            joint_attentions.append(torch.matmul(layer_att, joint_attentions[-1]))
        v = joint_attentions[num_layer]
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    # Row 0 (presumably the class token) attending to every remaining token,
    # laid out as the square patch grid.
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    # Scale so the peak is 1 and upsample to the image's (width, height).
    mask = cv2.resize(mask / mask.max(), img.size)
    if get_mask:
        return mask
    # Broadcast the mask over the colour channels and dim the image with it.
    return (mask[..., np.newaxis] * img).astype("uint8")
# Tab 1: attention-map visualiser. The three inputs are passed positionally
# to get_attention_map as (img, num_layer, get_mask).
attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[
        gr.Image(type="pil", label="Image"),
        # type="index" hands the callback the zero-based position of the
        # selected entry rather than its label, so "6" arrives as 5.
        gr.Dropdown(
            choices=[str(layer) for layer in range(1, 13)],
            label="Attention Layer",
            value="6",
            type="index",
        ),
        gr.Checkbox(label="Show Mask?"),
    ],
    outputs=gr.Image(type="pil", label="Attention Map").style(height=400),
    examples=example_list,
    title="Attention Maps π",
    description="The ViT Base architecture has 12 transformer Encoder layers (12 attention heads in each).",
    article="From the dropdown menu, select the Encoder layer (tick the checkbox to visualize only the mask).",
)
# Tab 2: image classifier. predict_image returns (label->probability dict,
# elapsed seconds), which feed the Label and Number outputs in that order.
classification_interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Number(label="Prediction time (secs)"),
    ],
    examples=example_list,
    title="Object Identification β ",
    description="ImageNet object identification using pretrained ViT Base (Patch Size: 16 | Image Size: 384) architecture.",
    article="Upload an image from the example list or choose one of your own [[ImageNet Classes](https://github.com/anishathalye/imagenet-simple-labels/blob/master/imagenet-simple-labels.json)].",
)
# Combine both interfaces into one app; the tab names are listed in the same
# order as the interfaces.
demo = gr.TabbedInterface(
    [attention_interface, classification_interface],
    ["Visualize Attention Maps", "Image Prediction"],
    title="ImageNet 1K π·",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()