import os
import json
import cv2
import torch
import numpy as np
import gradio as gr
import PIL
from PIL import Image
from typing import Tuple, Dict
from timeit import default_timer as timer
from torchvision import transforms
import warnings

warnings.filterwarnings('ignore')

# Bundled example images shown in both Gradio interfaces.
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Human-readable ImageNet-1k class names.
with open('labels/imagenet1k-simple-labels.json') as f:
    class_names = json.load(f)

from model import VisionTransformer
from capture_weights import vit_weights

# Build ViT-B/16 and load the ImageNet-21k -> ImageNet-2012 fine-tuned weights.
vision_transformer = VisionTransformer.from_name('ViT-B_16', num_classes=1000)
model_weights = torch.load('pretrained_weights/ViT-B_16_imagenet21k_imagenet2012.pth',
                           map_location=torch.device('cpu'))
vision_transformer.load_state_dict(model_weights)
vision_transformer.eval()  # disable dropout for deterministic inference

# Preprocessing: this ViT-B/16 checkpoint was fine-tuned at 384x384 and
# expects ImageNet-normalized input.
data_transforms = transforms.Compose([
    transforms.Resize(size=(384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])
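
# Shape sanity check (a sketch, not part of the app; the file name below is
# hypothetical -- any RGB image works):
#   x = data_transforms(Image.open("examples/some_image.jpg"))
#   x.shape  # -> torch.Size([3, 384, 384]) after Resize + ToTensor
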
def inv_normalize(tensor):
    """Min-max scale an image tensor back into the displayable 0-255 range."""
    tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) * (256 - 1e-5)
    return tensor

def inv_transform(tensor, normalize=True):
    """Convert a (C, H, W) tensor back to a PIL image."""
    if normalize:
        tensor = inv_normalize(tensor)
    array = tensor.detach().cpu().numpy()
    array = array.transpose(1, 2, 0).astype(np.uint8)
    return PIL.Image.fromarray(array)
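
# Round-trip sketch (hypothetical file name), showing how inv_transform
# recovers a viewable image from a tensor produced by data_transforms:
#   x = data_transforms(Image.open("examples/some_image.jpg"))
#   inv_transform(x).show()
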
def predict_image(image) -> Tuple[Dict, float]:
    """Return the top-3 predicted classes with probabilities for an input image."""
    input_tensor = data_transforms(image)
    start_time = timer()
    prediction_dict = {}
    with torch.inference_mode():
        [logits] = vision_transformer(input_tensor[None])
        probs = torch.softmax(logits, dim=0)
        topk_prob, topk_id = torch.topk(probs, 3)
    for i in range(topk_prob.size(0)):
        prediction_dict[class_names[topk_id[i].item()]] = topk_prob[i].item()
    prediction_time = round(timer() - start_time, 5)
    return prediction_dict, prediction_time
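
# Smoke-test sketch (assumes at least one image sits in examples/):
#   preds, secs = predict_image(Image.open(example_list[0][0]))
#   # preds maps class name -> probability; secs is the wall-clock inference time
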
def get_attention_map(img, num_layer=5, get_mask=False):
    """Attention-rollout visualization up to one encoder layer (0-indexed)."""
    x = data_transforms(img)
    logits, att_mat = vit_weights(x.unsqueeze(0))
    att_mat = torch.stack(att_mat).squeeze(1)
    # Average the attention weights across the 12 heads.
    att_mat = torch.mean(att_mat, dim=1)
    # To account for residual connections, add an identity matrix to the
    # attention matrix and re-normalize the weights.
    residual_att = torch.eye(att_mat.size(1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)
    # Recursively multiply the weight matrices to propagate attention
    # from the input through every layer up to the current one.
    joint_attentions = torch.zeros(aug_att_mat.size())
    joint_attentions[0] = aug_att_mat[0]
    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])
    v = joint_attentions[num_layer]
    # Tokens = 1 [CLS] token + grid_size**2 patch tokens.
    grid_size = int(np.sqrt(aug_att_mat.size(-1) - 1))
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    if get_mask:
        attn_map = (cv2.resize(mask / mask.max(), img.size) * 255).astype("uint8")
    else:
        mask = cv2.resize(mask / mask.max(), img.size)[..., np.newaxis]
        attn_map = (mask * np.asarray(img)).astype("uint8")
    return attn_map
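
# Usage sketch (hypothetical file name; num_layer indexes the 12 layers from 0):
#   img = Image.open("examples/some_image.jpg")
#   overlay = get_attention_map(img, num_layer=5)    # image weighted by rollout
#   heatmap = get_attention_map(img, get_mask=True)  # rollout mask alone
#   Image.fromarray(overlay).show()
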
attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[gr.Image(type="pil", label="Image"),
            gr.Dropdown(choices=[str(i) for i in range(1, 13)],
                        label="Attention Layer", value="6", type="index"),
            gr.Checkbox(label="Show Mask?")],
    outputs=gr.Image(type="pil", label="Attention Map", height=400),
    examples=example_list,
    title="Attention Maps",
    description="The ViT-Base architecture has 12 Transformer encoder layers, each with 12 attention heads.",
    article="Select an encoder layer from the dropdown menu (tick the checkbox to visualize only the mask)."
)

classification_interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
             gr.Number(label="Prediction time (secs)")],
    examples=example_list,
    title="Object Identification",
    description="ImageNet object identification with a pretrained ViT-Base (patch size 16, image size 384).",
    article="Choose an image from the example list or upload one of your own [[ImageNet Classes](https://github.com/anishathalye/imagenet-simple-labels/blob/master/imagenet-simple-labels.json)]."
)

demo = gr.TabbedInterface([attention_interface, classification_interface],
                          ["Visualize Attention Maps", "Image Prediction"],
                          title="ImageNet 1K")

if __name__ == "__main__":
    demo.launch()