import os
import PIL
import ast
import cv2
import json
import torch
import pickle
import torchvision
import numpy as np
import gradio as gr
from PIL import Image
from typing import Tuple, Dict
import matplotlib.pyplot as plt
from timeit import default_timer as timer
from torchvision import datasets, transforms

import warnings

warnings.filterwarnings('ignore')

# Example images offered in both tabs. Sorted so the gallery order is stable:
# os.listdir returns entries in arbitrary, platform-dependent order.
example_list = [["examples/" + example] for example in sorted(os.listdir("examples"))]

# Human-readable ImageNet-1k class names; indexed by the model's class id below.
with open('labels/imagenet1k-simple-labels.json', encoding='utf-8') as f:
    class_names = json.load(f)

from model import VisionTransformer
from capture_weights import vit_weights

vision_transformer = VisionTransformer.from_name('ViT-B_16', num_classes=1000)
model_weights = torch.load('pretrained_weights/ViT-B_16_imagenet21k_imagenet2012.pth',
                           map_location=torch.device('cpu'))
vision_transformer.load_state_dict(model_weights)
# VisionTransformer is presumably a torch nn.Module (it has load_state_dict);
# modules start in training mode, so switch to eval to disable dropout and
# other train-only behaviour. torch.inference_mode() alone does NOT do this.
# NOTE(review): vit_weights (capture_weights) may use its own model instance;
# confirm that one is in eval mode as well.
vision_transformer.eval()

# Preprocessing for the 384x384 ViT-B/16 model: resize, convert to a float
# tensor, then normalize with the per-channel mean/std listed here.
data_transforms = transforms.Compose([
    transforms.Resize(size=(384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225],)])


def _to_rgb(image):
    """Return *image* as a 3-channel RGB PIL image.

    ``data_transforms`` normalizes exactly three channels, so RGBA/palette
    PNGs or grayscale uploads would otherwise fail (and the attention overlay
    would broadcast against the wrong number of channels).
    """
    return image if image.mode == "RGB" else image.convert("RGB")


def inv_normalize(tensor):
    """Min-max rescale an image tensor into the 0-255 range.

    The upper bound is just below 256 so the maximum truncates to 255 when
    cast to uint8. A constant tensor (max == min) yields all zeros instead of
    NaNs from a 0/0 division.
    """
    lo, hi = tensor.min(), tensor.max()
    if hi == lo:
        # Degenerate input: nothing to stretch, return a black image.
        return torch.zeros_like(tensor, dtype=torch.float32)
    return (tensor - lo) / (hi - lo) * (256 - 1e-5)


def inv_transform(tensor, normalize=True):
    """Convert a channels-first image tensor back to a PIL image.

    Args:
        tensor: image tensor laid out as (C, H, W) (the transpose below
            moves channels last).
        normalize: when True (default), min-max rescale into 0-255 with
            ``inv_normalize`` first; when False the tensor is assumed to
            already hold 0-255 values.

    Returns:
        A ``PIL.Image.Image`` built from the uint8 array.
    """
    if normalize:
        tensor = inv_normalize(tensor)
    array = tensor.detach().cpu().numpy()
    array = array.transpose(1, 2, 0).astype(np.uint8)
    return PIL.Image.fromarray(array)


def predict_image(image) -> Tuple[Dict, float]:
    """Return the top-3 prediction classes with probabilities for an image.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input; any
            mode is accepted and converted to RGB.

    Returns:
        ``(predictions, seconds)``: ``predictions`` maps class name to softmax
        probability for the 3 most likely classes; ``seconds`` is the time
        spent after preprocessing, rounded to 5 decimals.
    """
    input_tensor = data_transforms(_to_rgb(image))
    start_time = timer()
    with torch.inference_mode():
        # Add a batch dim of one; unpacking checks exactly one row came back.
        [logits] = vision_transformer(input_tensor[None])
        probs = torch.softmax(logits, dim=0)
        topk_prob, topk_id = torch.topk(probs, 3)
    prediction_dict = {class_names[idx]: prob
                       for idx, prob in zip(topk_id.tolist(), topk_prob.tolist())}
    prediction_time = round(timer() - start_time, 5)
    return prediction_dict, prediction_time


def get_attention_map(img, num_layer=5, get_mask=False):
    """Visualize rolled-out attention for one encoder layer.

    Args:
        img: PIL image (any mode; converted to RGB).
        num_layer: 0-based layer index into the rolled-out attention stack.
            The Gradio dropdown uses ``type="index"``, so choice "6" arrives
            as 5.
        get_mask: if True, return only the image-sized mask scaled so its
            maximum is 1; otherwise return the image multiplied by that mask,
            as a uint8 array.
    """
    img = _to_rgb(img)
    x = data_transforms(img)
    # The maps are only displayed, never back-propagated through, so skip
    # building an autograd graph.
    with torch.no_grad():
        logits, att_mat = vit_weights(x.unsqueeze(0))

    # Stack the per-layer attention and drop the batch dim of one.
    # Assumes one (1, heads, tokens, tokens) tensor per layer — TODO confirm.
    att_mat = torch.stack(att_mat).squeeze(1)
    # Take the mean of the attention weights across the heads.
    att_mat = torch.mean(att_mat, dim=1)

    # To account for residual connections, add an identity matrix to the
    # attention matrix and re-normalize each row to sum to 1.
    residual_att = torch.eye(att_mat.size(1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

    # Recursively multiply the weight matrices (attention rollout).
    joint_attentions = torch.zeros(aug_att_mat.size())
    joint_attentions[0] = aug_att_mat[0]
    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

    v = joint_attentions[num_layer]
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    # Token 0 (presumably the CLS token) is excluded; its attention to the
    # remaining patch tokens is laid out on the square patch grid.
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    # cv2.resize takes dsize as (width, height), which matches PIL's img.size.
    mask = cv2.resize(mask / mask.max(), img.size)
    if get_mask:
        return mask
    overlay = mask[..., np.newaxis] * np.asarray(img)
    return overlay.astype("uint8")


attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[gr.Image(type="pil", label="Image"),
            gr.Dropdown(choices=["1", "2", "3", "4", "5", "6",
                                 "7", "8", "9", "10", "11", "12"],
                        label="Attention Layer", value="6", type="index"),
            gr.Checkbox(label="Show Mask?")],
    outputs=gr.Image(type="pil", label="Attention Map").style(height=400),
    examples=example_list,
    title="Attention Maps 🔍",
    description="The ViT Base architecture has 12 transformer Encoder layers (12 attention heads in each).",
    article="From the dropdown menu, select the Encoder layer (tick the checkbox to visualize only the mask)."
)

classification_interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
             gr.Number(label="Prediction time (secs)")],
    examples=example_list,
    title="Object Identification ✅",
    description="ImageNet object identification using pretrained ViT Base (Patch Size: 16 | Image Size: 384) architecture.",
    article="Upload an image from the example list or choose one of your own [[ImageNet Classes](https://github.com/anishathalye/imagenet-simple-labels/blob/master/imagenet-simple-labels.json)]."
)

demo = gr.TabbedInterface([attention_interface, classification_interface],
                          ["Visualize Attention Maps", "Image Prediction"],
                          title="ImageNet 1K 📷")

if __name__ == "__main__":
    demo.launch()