sneha committed on
Commit aa86478
1 Parent(s): 69678e2

initial commit

Files changed (6)
  1. app.py +102 -0
  2. attn_helper.py +107 -0
  3. ego4d.jpg +0 -0
  4. kitchen.jpg +0 -0
  5. rearrange.jpg +0 -0
  6. trifinger.jpg +0 -0
app.py ADDED
@@ -0,0 +1,102 @@
+ import numpy as np
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ import omegaconf
+ from hydra import utils
+ import os
+ import torch
+ import matplotlib.pyplot as plt
+ from attn_helper import VITAttentionGradRollout, overlay_attn
+ import vc_models
+ # import eaif_models
+ import torchvision
+
+ eai_filepath = vc_models.__file__.split('src')[0]
+
+ MODEL_DIR = os.path.join(eai_filepath, 'src', 'model_ckpts')
+ if not os.path.isdir(MODEL_DIR):
+     os.mkdir(MODEL_DIR)
+
+ REPO_ID = "facebook/vc1-base"
+ FILENAME = "config.yaml"
+ MODEL_TUPLE = None
+
+ def get_model():
+     # Instantiate the VC-1 model once and cache it in MODEL_TUPLE.
+     global MODEL_TUPLE
+     download_bin()
+     if MODEL_TUPLE is None:
+         model_cfg = omegaconf.OmegaConf.load(
+             hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
+         )
+         model_cfg['model']['checkpoint_path'] = 'model_ckpts/vc1_vitb.pth'
+         print(model_cfg)
+         MODEL_TUPLE = utils.instantiate(model_cfg)
+         MODEL_TUPLE[0].eval()
+     return MODEL_TUPLE  # (model, embedding_dim, transform, metadata)
+
+ def download_bin():
+     # Download the VC-1 ViT-B checkpoint from the Hub if it is not cached yet.
+     bin_file = 'vc1_vitb.pth'
+     bin_path = os.path.join(MODEL_DIR, bin_file)
+     print(bin_path)
+     if not os.path.isfile(bin_path):
+         model_bin = hf_hub_download(repo_id=REPO_ID, filename='pytorch_model.bin', local_dir=MODEL_DIR, local_dir_use_symlinks=True)
+         os.rename(model_bin, bin_path)
+         print(type(model_bin))
+
+
+ def run_attn(input_img):
+     download_bin()
+     model, embedding_dim, transform, metadata = get_model()
+     print(input_img.shape)
+     # Convert an HxWxC numpy image into a 1xCxHxW float tensor.
+     if input_img.shape[0] != 3:
+         input_img = input_img.transpose(2, 0, 1)
+     print(input_img.shape)
+     if len(input_img.shape) == 3:
+         input_img = torch.tensor(input_img).unsqueeze(0)
+     input_img = input_img.float()
+     resize_transform = torchvision.transforms.Resize((250, 250))
+     input_img = resize_transform(input_img)
+     x = transform(input_img)
+
+     # Hook the attention layers, then run a forward pass to populate them.
+     attention_rollout = VITAttentionGradRollout(model, head_fusion="mean")
+
+     y = model(x)
+     mask = attention_rollout.get_attn_mask()
+     print(input_img.shape)
+     print(mask.shape)
+     attn_img = overlay_attn(input_img[0].permute(1, 2, 0), mask)
+
+     # Visualize the embedding as a 16-row heatmap.
+     fig = plt.figure()
+     ax = fig.subplots()
+     im = ax.matshow(y.detach().numpy().reshape(16, -1))
+     plt.colorbar(im)
+
+     return attn_img, fig
+
+ # Alternative Blocks-based UI (unused):
+ # with gr.Blocks() as demo:
+ #     gr.Markdown("Visual Cortex Base Model")
+ #     input_img = gr.Image(shape=(250, 250))
+ #     output_img = gr.Image(shape=(250, 250))
+ #     output_plot = gr.Plot()
+ #     btn = gr.Button("Encode Representation")
+ #     gr.Examples(["./trifinger.jpg", "./rearrange.jpg", "./kitchen.jpg", "./ego4d.jpg"], input_img)
+ #     btn.click(fn=run_attn, inputs=input_img,
+ #               outputs=[output_img, output_plot])
+
+
+ input_img = gr.Image(shape=(250, 250))
+ output_img = gr.Image(shape=(250, 250))
+ output_plot = gr.Plot()
+
+ demo = gr.Interface(fn=run_attn, title="Visual Cortex Base Model",
+                     examples=["./trifinger.jpg", "./rearrange.jpg", "./kitchen.jpg", "./ego4d.jpg"],
+                     inputs=input_img, outputs=[output_img, output_plot])
+ demo.launch(share=True)
attn_helper.py ADDED
@@ -0,0 +1,107 @@
+ import cv2
+ from PIL import Image
+ import numpy as np
+ import torch
+
+ import PIL
+
+ def overlay_attn(original_image, mask):
+     # Colormap and alpha for the attention mask (COLORMAP_OCEAN).
+     colormap_attn, alpha_attn = cv2.COLORMAP_OCEAN, 1  # 0.85
+
+     # Resize mask to the original image size.
+     w, h = original_image.shape[0], original_image.shape[1]
+     mask = cv2.resize(mask / mask.max(), (h, w))[..., np.newaxis]
+
+     # Apply colormap to the mask.
+     cmap = cv2.applyColorMap(np.uint8(255 * mask), colormap_attn)
+
+     print(cmap.shape)
+     # Blend mask and original image.
+     grayscale_img = cv2.cvtColor(np.uint8(original_image), cv2.COLOR_RGB2GRAY)  # currently unused
+     alpha_blended = cv2.addWeighted(np.uint8(original_image), 1, cmap, alpha_attn, 0)
+     # alpha_blended = cmap
+
+     # Return as a PIL image.
+     final_im = Image.fromarray(alpha_blended)
+     # final_im = final_im.crop((0, 0, 250, 250))
+     return final_im
+
+
+ class VITAttentionGradRollout:
+     '''
+     Expects a timm-style ViT transformer model.
+     Registers forward hooks on every attention block and recomputes the
+     softmaxed attention maps so they can be rolled out into a single mask.
+     '''
+     def __init__(self, model, head_fusion='min', discard_ratio=0):
+         self.model = model
+         self.head_fusion = head_fusion
+         self.discard_ratio = discard_ratio
+         print(list(model.blocks.children()))
+
+         self.attentions = {}
+         for idx, module in enumerate(list(model.blocks.children())):
+             module.attn.register_forward_hook(self.get_attention(f"attn{idx}"))
+
+     def get_attention(self, name):
+         def hook(module, input, output):
+             with torch.no_grad():
+                 input = input[0]
+                 B, N, C = input.shape
+                 qkv = (
+                     module.qkv(input)
+                     .detach()
+                     .reshape(B, N, 3, module.num_heads, C // module.num_heads)
+                     .permute(2, 0, 3, 1, 4)
+                 )
+                 q, k, _ = (
+                     qkv[0],
+                     qkv[1],
+                     qkv[2],
+                 )  # make torchscript happy (cannot use tensor as tuple)
+                 attn = (q @ k.transpose(-2, -1)) * module.scale
+                 attn = attn.softmax(dim=-1)
+                 self.attentions[name] = attn
+         return hook
+
+     def get_attn_mask(self):
+         result = torch.eye(self.attentions['attn0'].size(-1)).to(self.attentions['attn0'].device)
+
+         # result = torch.eye(self.attentions['attn2'].size(-1)).to(self.attentions['attn2'].device)
+         with torch.no_grad():
+             # for attention in self.attentions.values():
+             # Roll out attention starting from block 11 (the final block for ViT-B).
+             for k in range(11, len(self.attentions.keys())):
+                 attention = self.attentions[f'attn{k}']
+                 if self.head_fusion == "mean":
+                     attention_heads_fused = attention.mean(axis=1)
+                 elif self.head_fusion == "max":
+                     attention_heads_fused = attention.max(axis=1)[0]
+                 elif self.head_fusion == "min":
+                     attention_heads_fused = attention.min(axis=1)[0]
+                 else:
+                     raise ValueError("Attention head fusion type not supported")
+
+                 # Drop the lowest attentions, but
+                 # don't drop the class token.
+                 flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
+                 _, indices = flat.topk(int(flat.size(-1) * self.discard_ratio), -1, False)
+                 indices = indices[indices != 0]
+                 flat[0, indices] = 0
+                 I = torch.eye(attention_heads_fused.size(-1)).to(attention_heads_fused.device)
+                 a = (attention_heads_fused + 1.0 * I) / 2
+                 a = a / a.sum(dim=-1).unsqueeze(-1)
+
+                 result = torch.matmul(a, result)
+
+         # Look at the total attention between the class token
+         # and the image patches.
+         mask = result[0, 0, 1:]
+         # In the case of a 224x224 image, this brings us from 196 patches to a 14x14 grid.
+         width = int(mask.size(-1) ** 0.5)
+         mask = mask.reshape(width, width).detach().cpu().numpy()
+         mask = mask / np.max(mask)
+         return mask
ego4d.jpg ADDED
kitchen.jpg ADDED
rearrange.jpg ADDED
trifinger.jpg ADDED