sneha committed on
Commit
46f48ca
1 Parent(s): def57e6

change attn map appearance, simplify

Files changed (2)
  1. app.py +11 -29
  2. attn_helper.py +2 -4
app.py CHANGED
@@ -8,7 +8,6 @@ import torch
 import matplotlib.pyplot as plt
 from attn_helper import VITAttentionGradRollout, overlay_attn
 import vc_models
-#import eaif_models
 import torchvision
 
 
@@ -18,7 +17,6 @@ MODEL_DIR=os.path.join(os.path.dirname(eai_filepath),'model_ckpts')
 if not os.path.isdir(MODEL_DIR):
     os.mkdir(MODEL_DIR)
 
-
 FILENAME = "config.yaml"
 BASE_MODEL_TUPLE = None
 LARGE_MODEL_TUPLE = None
@@ -31,8 +29,6 @@ def get_model(model_name):
         model_cfg = omegaconf.OmegaConf.load(
             hf_hub_download(repo_id=repo_name, filename=FILENAME,token=HF_TOKEN)
         )
-        # model_cfg['model']['checkpoint_path'] = None
-        # model_cfg['model']['checkpoint_path'] = 'model_ckpts/vc1_vitb.pth'
         BASE_MODEL_TUPLE = utils.instantiate(model_cfg)
         BASE_MODEL_TUPLE[0].eval()
         model = BASE_MODEL_TUPLE
@@ -41,8 +37,6 @@ def get_model(model_name):
         model_cfg = omegaconf.OmegaConf.load(
             hf_hub_download(repo_id=repo_name, filename=FILENAME,token=HF_TOKEN)
         )
-        # model_cfg['model']['checkpoint_path'] = None
-        # model_cfg['model']['checkpoint_path'] = 'model_ckpts/vc1_vitb.pth'
         LARGE_MODEL_TUPLE = utils.instantiate(model_cfg)
         LARGE_MODEL_TUPLE[0].eval()
         model = LARGE_MODEL_TUPLE
@@ -51,7 +45,7 @@ def get_model(model_name):
     elif model_name == 'vc1-large':
         model = LARGE_MODEL_TUPLE
 
-    return model #model,embedding_dim,transform,metadata
+    return model
 
 def download_bin(model):
     bin_file = ""
@@ -61,14 +55,15 @@ def download_bin(model):
         bin_file = 'vc1_vitb.pth'
     else:
         raise NameError("model not found: " + model)
-    repo_name = 'facebook/' + model
+
+    repo_name = 'facebook/' + model
     bin_path = os.path.join(MODEL_DIR,bin_file)
     if not os.path.isfile(bin_path):
         model_bin = hf_hub_download(repo_id=repo_name, filename='pytorch_model.bin',local_dir=MODEL_DIR,local_dir_use_symlinks=True,token=HF_TOKEN)
         os.rename(model_bin, bin_path)
 
 
-def run_attn(input_img, model="vc1-base",fusion="min"):
+def run_attn(input_img, model="vc1-base"):
     download_bin(model)
     model, embedding_dim, transform, metadata = get_model(model)
     if input_img.shape[0] != 3:
@@ -80,33 +75,20 @@ def run_attn(input_img, model="vc1-base",fusion="min"):
     input_img = resize_transform(input_img)
     x = transform(input_img)
 
-    attention_rollout = VITAttentionGradRollout(model,head_fusion=fusion)
+    attention_rollout = VITAttentionGradRollout(model,head_fusion="max",discard_ratio=0.89)
 
     y = model(x)
     mask = attention_rollout.get_attn_mask()
    attn_img = overlay_attn(input_img[0].permute(1,2,0),mask)
-
-    fig = plt.figure()
-    ax = fig.subplots()
-    print(y.shape)
-    im = ax.matshow(y.detach().numpy().reshape(16,-1))
-    plt.colorbar(im)
-
-    return attn_img, fig
+    return attn_img
 
 model_type = gr.Dropdown(
     ["vc1-base", "vc1-large"], label="Model Size", value="vc1-base")
 input_img = gr.Image(shape=(250,250))
-input_button = gr.Radio(["min", "max", "mean"], value="min",label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
 output_img = gr.Image(shape=(250,250))
-output_plot = gr.Plot()
-
-css = ".output-image, .input-image, .image-preview {height: 600px !important}"
-
-markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention of the last layer of the transformer.\n \
-    The user can decide how the attention heads will be combined. \
-    Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 for VC1-Base or 16x64 grid for VC1-Large."
-demo = gr.Interface(fn=run_attn, title="Visual Cortex Base Model", description=markdown,
-                    examples=[[os.path.join('./imgs',x),None,None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
-                    inputs=[input_img,model_type,input_button],outputs=[output_img,output_plot],css=css)
+css = "#component-2, .input-image, .image-preview {height: 240px !important}"
+markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention(green) of the last layer of the transformer."
+demo = gr.Interface(fn=run_attn, title="Visual Cortex Model", description=markdown,
+                    examples=[[os.path.join('./imgs',x),None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
+                    inputs=[input_img,model_type],outputs=output_img,css=css)
 demo.launch()
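Editor's note on the simplification above: run_attn no longer takes a fusion argument; it hardcodes head_fusion="max" and discard_ratio=0.89 when building VITAttentionGradRollout, and the embedding matshow figure is dropped so only the attention overlay is returned. The following is a minimal, illustrative sketch of what those two settings do to one layer's attention; fuse_and_filter is a hypothetical helper, not the Space's VITAttentionGradRollout, and the exact way discard_ratio is applied inside attn_helper.py is an assumption here.

import torch

def fuse_and_filter(attn, head_fusion="max", discard_ratio=0.89):
    # attn: (batch, heads, tokens, tokens) attention probabilities from one block
    if head_fusion == "max":          # the value now hardcoded in app.py
        fused = attn.max(dim=1).values
    elif head_fusion == "mean":
        fused = attn.mean(dim=1)
    else:                             # "min", the old UI default
        fused = attn.min(dim=1).values
    # Zero out the lowest discard_ratio fraction of attention weights,
    # keeping only the strongest ~11% for the heatmap.
    flat = fused.flatten(1)
    k = int(flat.shape[-1] * discard_ratio)
    low = flat.topk(k, dim=-1, largest=False).indices
    flat = flat.scatter(-1, low, 0.0)
    return flat.view_as(fused)

attn = torch.softmax(torch.randn(1, 12, 197, 197), dim=-1)  # e.g. ViT-B on a 224x224 input
print(fuse_and_filter(attn).shape)  # torch.Size([1, 197, 197])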
attn_helper.py CHANGED
@@ -9,7 +9,7 @@ def overlay_attn(original_image,mask):
     # Colormap and alpha for attention mask
     # COLORMAP_OCEAN
     # COLORMAP_OCEAN
-    colormap_attn, alpha_attn = cv2.COLORMAP_JET, 1 #0.85
+    colormap_attn, alpha_attn = cv2.COLORMAP_VIRIDIS, 1 #0.85
 
     # Resize mask to original image size
     w, h = original_image.shape[0], original_image.shape[1]
@@ -18,12 +18,11 @@ def overlay_attn(original_image,mask):
     # Apply colormap to mask
     cmap = cv2.applyColorMap(np.uint8(255 * mask), colormap_attn)
 
-    print(cmap.shape)
     # Blend mask and original image
     # grayscale_img = cv2.cvtColor(np.uint8(original_image), cv2.COLOR_RGB2GRAY)
     # grayscale_img = cv2.cvtColor(grayscale_img, cv2.COLOR_GRAY2RGB)
     # alpha_blended = cv2.addWeighted(np.uint8(original_image),1, cmap, alpha_attn, 0)
-    alpha_blended = cv2.addWeighted(np.uint8(original_image),0.1, cmap, 0.9, 0)
+    alpha_blended = cv2.addWeighted(np.uint8(original_image),0.4, cmap, 0.6, 0)
 
 
     # alpha_blended = cmap
@@ -45,7 +44,6 @@ class VITAttentionGradRollout:
         self.model = model
         self.head_fusion = head_fusion
         self.discard_ratio = discard_ratio
-        print(list(model.blocks.children()))
 
         self.attentions = {}
         for idx, module in enumerate(list(model.blocks.children())):
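Editor's note on the appearance change: the heatmap colormap moves from JET to VIRIDIS and the blend now keeps more of the original image (0.4 image / 0.6 heatmap instead of 0.1 / 0.9). A self-contained sketch of that blend, assuming an HxWx3 uint8 image and a float mask in [0, 1]; overlay_sketch is a hypothetical stand-in for overlay_attn, not the Space's code.

import cv2
import numpy as np

def overlay_sketch(image_rgb, mask):
    # Resize the coarse attention mask to the image size (cv2.resize takes (width, height))
    mask = cv2.resize(mask.astype(np.float32), (image_rgb.shape[1], image_rgb.shape[0]))
    # Colorize with VIRIDIS, matching the updated overlay_attn
    heat = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_VIRIDIS)
    # New weights from this commit: 40% image, 60% heatmap
    # (channel order between RGB image and BGR colormap is ignored, as in the original)
    return cv2.addWeighted(np.uint8(image_rgb), 0.4, heat, 0.6, 0)

img = np.random.randint(0, 255, (250, 250, 3), dtype=np.uint8)
m = np.random.rand(16, 16)            # coarse ViT attention grid
print(overlay_sketch(img, m).shape)   # (250, 250, 3)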