import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
import omegaconf
from hydra import utils
import os
import torch
import matplotlib.pyplot as plt
from attn_helper import VITAttentionGradRollout, overlay_attn
import vc_models
import torchvision
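
# Hugging Face access token, assumed to be provided via the HF_ACC_TOKEN environment variable.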
HF_TOKEN = os.environ['HF_ACC_TOKEN']
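
# Checkpoints are stored in a model_ckpts directory next to the installed vc_models package.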
eai_filepath = vc_models.__file__.split('src')[0]
MODEL_DIR = os.path.join(os.path.dirname(eai_filepath), 'model_ckpts')
if not os.path.isdir(MODEL_DIR):
    os.mkdir(MODEL_DIR)
FILENAME = "config.yaml"
BASE_MODEL_TUPLE = None
LARGE_MODEL_TUPLE = None

def get_model(model_name):
    """Load and cache the requested VC-1 model, returning the vc_models tuple
    (model, embedding_dim, transform, metadata)."""
    global BASE_MODEL_TUPLE, LARGE_MODEL_TUPLE
    download_bin(model_name)
    model = None
    if BASE_MODEL_TUPLE is None and model_name == 'vc1-base':
        repo_name = "facebook/" + model_name
        model_cfg = omegaconf.OmegaConf.load(
            hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
        )
        BASE_MODEL_TUPLE = utils.instantiate(model_cfg)
        BASE_MODEL_TUPLE[0].eval()
        model = BASE_MODEL_TUPLE
    elif LARGE_MODEL_TUPLE is None and model_name == 'vc1-large':
        repo_name = "facebook/" + model_name
        model_cfg = omegaconf.OmegaConf.load(
            hf_hub_download(repo_id=repo_name, filename=FILENAME, token=HF_TOKEN)
        )
        LARGE_MODEL_TUPLE = utils.instantiate(model_cfg)
        LARGE_MODEL_TUPLE[0].eval()
        model = LARGE_MODEL_TUPLE
    elif model_name == 'vc1-base':
        model = BASE_MODEL_TUPLE
    elif model_name == 'vc1-large':
        model = LARGE_MODEL_TUPLE
    return model

def download_bin(model):
    """Download the pytorch_model.bin checkpoint for the given model size into MODEL_DIR."""
    bin_file = ""
    if model == "vc1-large":
        bin_file = 'vc1_vitl.pth'
    elif model == "vc1-base":
        bin_file = 'vc1_vitb.pth'
    else:
        raise NameError("model not found: " + model)

    repo_name = 'facebook/' + model
    bin_path = os.path.join(MODEL_DIR, bin_file)
    if not os.path.isfile(bin_path):
        model_bin = hf_hub_download(repo_id=repo_name, filename='pytorch_model.bin',
                                    local_dir=MODEL_DIR, local_dir_use_symlinks=True, token=HF_TOKEN)
        # Rename the downloaded file to the checkpoint name the model config expects.
        os.rename(model_bin, bin_path)

def run_attn(input_img, model="vc1-base", discard_ratio=0.89):
    """Run the selected model on an image and return the attention overlay."""
    download_bin(model)
    model, embedding_dim, transform, metadata = get_model(model)

    # Convert the HWC numpy image from Gradio into a batched CxHxW float tensor.
    if input_img.shape[0] != 3:
        input_img = input_img.transpose(2, 0, 1)
    if len(input_img.shape) == 3:
        input_img = torch.tensor(input_img).unsqueeze(0)
    input_img = input_img.float()
    resize_transform = torchvision.transforms.Resize((250, 250))
    input_img = resize_transform(input_img)
    x = transform(input_img)

    # The forward pass populates the attention maps that the rollout reads back.
    attention_rollout = VITAttentionGradRollout(model, head_fusion="max", discard_ratio=discard_ratio)
    y = model(x)
    mask = attention_rollout.get_attn_mask()
    attn_img = overlay_attn(input_img[0].permute(1, 2, 0), mask)
    return attn_img

# Gradio UI components.
model_type = gr.Dropdown(["vc1-base", "vc1-large"], label="Model Size", value="vc1-base")
input_img = gr.Image(shape=(250, 250))
discard_ratio = gr.Slider(0, 1, value=0.89)
output_img = gr.Image(shape=(250, 250))
css = "#component-2, .input-image, .image-preview {height: 240px !important}"
markdown = "This is a demo for the Visual Cortex models. When passed an image input, it displays the attention (green) of the last layer of the transformer."

demo = gr.Interface(fn=run_attn, title="Visual Cortex Model", description=markdown,
                    examples=[[os.path.join('./imgs', x), None, None] for x in os.listdir(os.path.join(os.getcwd(), 'imgs')) if 'jpg' in x],
                    inputs=[input_img, model_type, discard_ratio], outputs=output_img, css=css)
demo.launch()