|
import os |
|
import numpy as np |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
|
|
import PIL |
|
from PIL import Image |
|
|
|
import torch |
|
import torchvision |
|
from torchvision import datasets, transforms |
|
|
|
import vision_transformer as vits |
|
|
|
# --- Model / checkpoint configuration ---------------------------------------
arch: str = "vit_small"  # key looked up in vision_transformer.__dict__
mode: str = "simpool"  # forwarded as the `mode` argument of the ViT constructor
gamma = None  # forwarded unchanged as the `gamma` argument of the ViT
patch_size: int = 16  # ViT patch size; also used to reshape the attention map
num_classes: int = 0  # forwarded as the `num_classes` argument of the ViT
checkpoint: str = "checkpoints/vits_dino_simpool_no_gamma_ep100.pth"  # weights file passed to torch.load
checkpoint_key: str = "teacher"  # entry of the loaded checkpoint dict that holds the weights
|
|
|
# --- Visualisation / display settings ---------------------------------------
# Colormap applied to the normalised attention values before display.
cm = plt.get_cmap('viridis')
# Side length (pixels) the rendered attention map is resized to.
attn_map_size = 224
# Size requested for the Gradio output image widget.
width_display, height_display = 290, 290
|
|
|
example_dir = "examples/"

# File extensions accepted as example images (matched case-insensitively).
_EXAMPLE_IMAGE_EXTENSIONS = (
    ".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff",
)


def list_example_images(directory):
    """Return Gradio example rows ``[[path], ...]`` for the images in *directory*.

    Only regular files with an image extension are kept, so stray entries such
    as ``.DS_Store`` or sub-directories do not turn into broken examples.
    Entries are sorted by name because ``os.listdir`` order is arbitrary, which
    keeps the example gallery order stable between runs.

    Args:
        directory: Folder to scan (not recursive).

    Returns:
        A list of single-element lists, one per image path (the row format
        ``gr.Interface(examples=...)`` expects for a single input). An empty
        list when *directory* does not exist.
    """
    if not os.path.isdir(directory):
        return []
    paths = (os.path.join(directory, name) for name in sorted(os.listdir(directory)))
    return [
        [path]
        for path in paths
        if os.path.isfile(path) and path.lower().endswith(_EXAMPLE_IMAGE_EXTENSIONS)
    ]


example_list = list_example_images(example_dir)
|
|
|
|
|
|
|
# Build the ViT variant selected by ``arch`` and load its pretrained weights.
model = vits.__dict__[arch](
    mode=mode,
    gamma=gamma,
    patch_size=patch_size,
    num_classes=num_classes,
)

# map_location="cpu": the model is never moved off the CPU in this script, and
# remapping lets a checkpoint saved from a GPU run load on hosts without CUDA.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# On torch >= 2.6 (weights_only=True by default) a checkpoint that also pickles
# non-tensor Python objects may be rejected — confirm for this checkpoint.
state_dict = torch.load(checkpoint, map_location="cpu")
state_dict = state_dict[checkpoint_key]

# Strip wrapper prefixes so names match the bare ViT. "module." presumably comes
# from DistributedDataParallel and "backbone." from a training wrapper.
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}

# Keep only parameters the model defines; anything else in the checkpoint
# (presumably a projection head) is dropped. Build the key set once rather than
# re-creating model.state_dict() for every checkpoint entry.
_model_keys = set(model.state_dict())
state_dict = {k: v for k, v in state_dict.items() if k in _model_keys}

# strict=True still raises if any model parameter is missing from the
# checkpoint; unexpected keys cannot remain after the filter above.
msg = model.load_state_dict(state_dict, strict=True)

# Evaluation mode (affects train/eval-dependent layers such as dropout, if any).
model.eval()
|
|
|
def get_attention_map(img, resolution=32):
    """Render the SimPool attention of ``model`` over ``img`` as a colour image.

    Args:
        img: Input ``PIL.Image`` (from the ``gr.Image(type="pil")`` input), or
            ``None`` when no image was supplied.
        resolution: Side length, in patches, of the attention grid. The image
            is resized to ``resolution * patch_size`` pixels per side, so the
            map has ``resolution x resolution`` cells. Passed through ``int``
            in case the dropdown delivers the value as a string.

    Returns:
        An RGB ``PIL.Image`` of size ``attn_map_size x attn_map_size``. The
        attention is min-max normalised to [0, 1], values at or below 0.1 are
        set to 0, the result is coloured with ``cm`` and upscaled with
        nearest-neighbour resampling. Returns ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    resolution = int(resolution)
    # Use the model's own patch size so the grid really is resolution x resolution.
    input_size = resolution * patch_size
    grid = input_size // patch_size

    data_transforms = transforms.Compose([
        transforms.Resize((input_size, input_size),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # Normalize uses 3-channel statistics; uploads may be RGBA, L, P, ...
    x = data_transforms(img.convert("RGB"))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        attn = model.get_simpool_attention(x.unsqueeze(0))
        # Assumes one attention weight per patch token (grid * grid values);
        # reshape raises otherwise.
        attn = attn.reshape(grid, grid)

        lo, hi = attn.min(), attn.max()
        span = hi - lo
        if span > 0:
            attn = (attn - lo) / span
        else:
            # Constant map: avoid 0/0 NaNs and render it as zero attention.
            attn = torch.zeros_like(attn)
        # Zero out weak responses (values <= 0.1) to declutter the map.
        attn = torch.threshold(attn, 0.1, 0)

    # cm maps each value to an RGBA float in [0, 1]; scale to 8-bit and drop alpha.
    rgba = cm(attn.detach().cpu().numpy())
    attn_img = Image.fromarray(np.uint8(rgba * 255)).convert('RGB')
    return attn_img.resize((attn_map_size, attn_map_size), resample=Image.NEAREST)
|
|
|
# Gradio UI: an input image plus a grid-resolution dropdown go in, and the
# rendered SimPool attention map comes out.
attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=[16, 32, 64, 128],
                    label="Attention Map Resolution",
                    value=32),
    ],
    outputs=gr.Image(type="pil", label="SimPool Attention Map",
                     width=width_display, height=height_display),
    # None (the Interface default) means "no examples"; use it instead of an
    # empty list when no example images were found.
    examples=example_list or None,
    title="Explore the Attention Maps of SimPool🔍",
    # The loaded checkpoint is a DINO teacher ("vits_dino_..." / "teacher"),
    # i.e. self-supervised, so the description must not claim supervised training.
    description="Upload or use one of the selected images to explore the intricate focus areas of a ViT-S model with SimPool, trained on ImageNet-1k, with self-supervision (DINO)."
)

# Wrap the single interface in a tabbed layout with an overall title.
demo = gr.TabbedInterface([attention_interface],
                          ["Visualize Attention Maps"], title="SimPool Attention Map Visualizer 🌌")

if __name__ == "__main__":
    # share=True additionally requests a public Gradio share link.
    demo.launch(share=True)