File size: 3,665 Bytes
aa86478
 
 
 
 
 
 
 
 
 
 
 
38d62ce
 
aa86478
8b10fec
aa86478
 
 
 
c54235c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa86478
c54235c
aa86478
c54235c
 
 
 
 
 
 
 
46f48ca
aa86478
c54235c
cad7903
c54235c
 
 
 
cad7903
 
46f48ca
 
aa86478
 
5161efd
aa86478
 
 
30ae246
c54235c
 
aa86478
 
 
 
 
 
 
 
 
30ae246
aa86478
 
 
 
46f48ca
aa86478
c54235c
ba48df9
aa86478
30ae246
 
aa86478
46f48ca
 
 
30ae246
 
def57e6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
import omegaconf
from hydra import utils
import os
import torch
import matplotlib.pyplot as plt
from attn_helper import VITAttentionGradRollout, overlay_attn
import vc_models
import torchvision


# Hugging Face access token. Read eagerly so that a missing HF_ACC_TOKEN
# environment variable fails at startup with a KeyError.
HF_TOKEN = os.environ['HF_ACC_TOKEN']

# Everything in the installed vc_models module path that precedes the first
# occurrence of 'src'.
eai_filepath = vc_models.__file__.split('src')[0]

# Checkpoints are kept in a 'model_ckpts' directory under that prefix.
MODEL_DIR = os.path.join(os.path.dirname(eai_filepath), 'model_ckpts')
# exist_ok avoids the check-then-create race of isdir() followed by mkdir().
os.makedirs(MODEL_DIR, exist_ok=True)

# Name of the model config file fetched from each Hugging Face repo.
FILENAME = "config.yaml"

# Lazily-filled caches for the instantiated models (populated by get_model).
BASE_MODEL_TUPLE = None
LARGE_MODEL_TUPLE = None
def get_model(model_name):
    """Return the instantiated model tuple for ``model_name``, building it on first use.

    The checkpoint is fetched first via ``download_bin`` (which raises
    ``NameError`` for names it does not recognise). The repo's
    ``config.yaml`` is then downloaded, loaded with OmegaConf and
    instantiated through Hydra. The result is cached in the module-level
    ``BASE_MODEL_TUPLE`` / ``LARGE_MODEL_TUPLE`` so later calls reuse it.

    Args:
        model_name: ``'vc1-base'`` or ``'vc1-large'``.

    Returns:
        The object produced by ``utils.instantiate`` for the model's config.
        Its first element is put into eval mode. ``None`` is returned for
        any other name that gets past ``download_bin``.
    """
    global BASE_MODEL_TUPLE, LARGE_MODEL_TUPLE

    # Make sure the weights are on disk before the config is instantiated.
    download_bin(model_name)

    if model_name == 'vc1-base':
        if BASE_MODEL_TUPLE is None:
            config_path = hf_hub_download(
                repo_id="facebook/" + model_name, filename=FILENAME, token=HF_TOKEN
            )
            BASE_MODEL_TUPLE = utils.instantiate(omegaconf.OmegaConf.load(config_path))
            BASE_MODEL_TUPLE[0].eval()
        return BASE_MODEL_TUPLE

    if model_name == 'vc1-large':
        if LARGE_MODEL_TUPLE is None:
            config_path = hf_hub_download(
                repo_id="facebook/" + model_name, filename=FILENAME, token=HF_TOKEN
            )
            LARGE_MODEL_TUPLE = utils.instantiate(omegaconf.OmegaConf.load(config_path))
            LARGE_MODEL_TUPLE[0].eval()
        return LARGE_MODEL_TUPLE

    return None

def download_bin(model):
    """Ensure the checkpoint file for ``model`` exists inside ``MODEL_DIR``.

    If the expected ``.pth`` file is not already present, the repo's
    ``pytorch_model.bin`` is downloaded from the Hugging Face Hub into
    ``MODEL_DIR`` and renamed to the expected checkpoint file name.

    Args:
        model: ``"vc1-base"`` or ``"vc1-large"``.

    Raises:
        NameError: if ``model`` is not one of the supported names. This
            includes non-string values such as ``None``.
    """
    checkpoint_names = {
        "vc1-large": 'vc1_vitl.pth',
        "vc1-base": 'vc1_vitb.pth',
    }
    try:
        bin_file = checkpoint_names[model]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments. str() keeps the message
        # well-formed for non-string values instead of failing on "+".
        raise NameError("model not found: " + str(model)) from None

    repo_name = 'facebook/' + model
    bin_path = os.path.join(MODEL_DIR, bin_file)
    if not os.path.isfile(bin_path):
        model_bin = hf_hub_download(
            repo_id=repo_name,
            filename='pytorch_model.bin',
            local_dir=MODEL_DIR,
            local_dir_use_symlinks=True,
            token=HF_TOKEN,
        )
        # Move the download to the checkpoint file name checked above.
        os.rename(model_bin, bin_path)


def run_attn(input_img, model="vc1-base",discard_ratio=0.89):
    """Run the selected model on an image and return its attention overlay.

    Args:
        input_img: Image as an array or tensor. A 3-D input whose first
            axis is not of size 3 is treated as channels-last and moved to
            channels-first. A 4-D input is used as given (assumed to be
            batch-first, channels-first; TODO confirm against callers).
        model: ``"vc1-base"`` or ``"vc1-large"``. ``None`` selects the
            default ``"vc1-base"``.
        discard_ratio: Passed to ``VITAttentionGradRollout``. ``None``
            selects the default ``0.89``.

    Returns:
        Whatever ``overlay_attn`` returns for the resized image and the
        attention mask.

    Raises:
        ValueError: if ``input_img`` is ``None`` or is not 3-D/4-D.
        NameError: if ``model`` is not a supported model name (raised by
            ``download_bin`` through ``get_model``).
    """
    if input_img is None:
        raise ValueError("run_attn requires an input image, got None")
    # Treat None as "use the default" so callers that leave these unset
    # (e.g. example rows with None placeholders) still work.
    if model is None:
        model = "vc1-base"
    if discard_ratio is None:
        discard_ratio = 0.89

    # get_model already makes sure the checkpoint has been downloaded.
    vc_model, _embedding_dim, transform, _metadata = get_model(model)

    if not torch.is_tensor(input_img):
        input_img = torch.tensor(input_img)
    if input_img.dim() == 3:
        if input_img.shape[0] != 3:
            # Channels-last to channels-first.
            input_img = input_img.permute(2, 0, 1).contiguous()
        input_img = input_img.unsqueeze(0)
    elif input_img.dim() != 4:
        raise ValueError(
            "run_attn expects a 3-D or 4-D image, got %d dimension(s)" % input_img.dim()
        )
    input_img = input_img.float()

    resize_transform = torchvision.transforms.Resize((250,250))
    input_img = resize_transform(input_img)
    x = transform(input_img)

    attention_rollout = VITAttentionGradRollout(vc_model,head_fusion="max",discard_ratio=discard_ratio)

    # The output is not needed. NOTE(review): presumably the rollout helper
    # records attention during this forward pass; confirm in attn_helper.
    vc_model(x)
    mask = attention_rollout.get_attn_mask()
    return overlay_attn(input_img[0].permute(1,2,0),mask)

# Input widgets. They are passed to gr.Interface below in the order
# run_attn takes its arguments: image, model name, discard ratio.
model_type = gr.Dropdown(
            ["vc1-base", "vc1-large"], label="Model Size", value="vc1-base")
input_img = gr.Image(shape=(250,250))
discard_ratio = gr.Slider(0,1,value=0.89)

output_img = gr.Image(shape=(250,250))
css = "#component-2, .input-image, .image-preview {height: 240px !important}"
markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention(green) of the last layer of the transformer."

# One example row per .jpg file in ./imgs. Only the image column is filled;
# the other two inputs are left as None. The listing is sorted because
# os.listdir returns entries in arbitrary order. The suffix test avoids
# picking up non-image files that merely contain "jpg" in their name.
_example_dir = os.path.join(os.getcwd(),'imgs')
_examples = [
    [os.path.join('./imgs', name), None, None]
    for name in sorted(os.listdir(_example_dir))
    if name.lower().endswith('.jpg')
]

demo = gr.Interface(fn=run_attn, title="Visual Cortex Model", description=markdown,
                    examples=_examples,
                    inputs=[input_img,model_type,discard_ratio],outputs=output_img,css=css)
demo.launch()