sneha committed on
Commit
5ded884
1 Parent(s): 443912c

add radio buttons

Files changed (2)
  1. app.py +5 -4
  2. attn_helper.py +9 -3
app.py CHANGED
@@ -43,7 +43,7 @@ def download_bin():
     os.rename(model_bin, bin_path)
 
 
-def run_attn(input_img):
+def run_attn(input_img,fusion):
     download_bin()
     model, embedding_dim, transform, metadata = get_model()
     if input_img.shape[0] != 3:
@@ -55,7 +55,7 @@ def run_attn(input_img):
     input_img = resize_transform(input_img)
     x = transform(input_img)
 
-    attention_rollout = VITAttentionGradRollout(model,head_fusion="mean")
+    attention_rollout = VITAttentionGradRollout(model,head_fusion=fusion)
 
     y = model(x)
     mask = attention_rollout.get_attn_mask()
@@ -69,10 +69,11 @@ def run_attn(input_img):
     return attn_img, fig
 
 input_img = gr.Image(shape=(250,250))
+input_button = gr.Radio(["min", "max", "mean"], label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
 output_img = gr.Image(shape=(250,250))
 output_plot = gr.Plot()
 
 demo = gr.Interface(fn=run_attn, title="Visual Cortex Base Model",
-                    examples=[os.path.join('./imgs',x) for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
-                    inputs=input_img,outputs=[output_img,output_plot])
+                    examples=[[os.path.join('./imgs',x),None] for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
+                    inputs=[input_img,input_button],outputs=[output_img,output_plot])
 demo.launch()
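
Note on the wiring: the new fusion argument reaches run_attn because gr.Interface calls fn with one value per input component, in the order of the inputs list, so the gr.Radio selection arrives as the second positional argument; the trailing None in each example row presumably leaves the radio at its default so the examples only pre-fill the image. Below is a minimal sketch of that mechanism, assuming the same Gradio API the Space uses; echo_fusion and its text components are hypothetical stand-ins, not part of this commit.

```python
import gradio as gr

# Hypothetical stand-in for run_attn (illustration only): gr.Interface passes
# one value per input component, in order, so the Radio choice is the second
# positional argument.
def echo_fusion(text, fusion):
    return f"input={text!r}, head_fusion={fusion}"

fusion_radio = gr.Radio(["min", "max", "mean"], label="Attention Head Fusion",
                        info="How to combine attention across the heads.")

demo = gr.Interface(fn=echo_fusion,
                    inputs=[gr.Textbox(label="Input"), fusion_radio],
                    outputs=gr.Textbox(label="Result"))

if __name__ == "__main__":
    demo.launch()
```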
attn_helper.py CHANGED
@@ -9,7 +9,7 @@ def overlay_attn(original_image,mask):
     # Colormap and alpha for attention mask
     # COLORMAP_OCEAN
     # COLORMAP_OCEAN
-    colormap_attn, alpha_attn = cv2.COLORMAP_OCEAN, 1 #0.85
+    colormap_attn, alpha_attn = cv2.COLORMAP_JET, 1 #0.85
 
     # Resize mask to original image size
     w, h = original_image.shape[0], original_image.shape[1]
@@ -20,9 +20,14 @@ def overlay_attn(original_image,mask):
 
     print(cmap.shape)
     # Blend mask and original image
-    grayscale_img = cv2.cvtColor(np.uint8(original_image), cv2.COLOR_RGB2GRAY)
-    alpha_blended = cv2.addWeighted(np.uint8(original_image),1, cmap, alpha_attn, 0)
+    # grayscale_img = cv2.cvtColor(np.uint8(original_image), cv2.COLOR_RGB2GRAY)
+    # grayscale_img = cv2.cvtColor(grayscale_img, cv2.COLOR_GRAY2RGB)
+    # alpha_blended = cv2.addWeighted(np.uint8(original_image),1, cmap, alpha_attn, 0)
+    alpha_blended = cv2.addWeighted(np.uint8(original_image),0.1, cmap, 0.9, 0)
+
+
     # alpha_blended = cmap
+
 
     # Save image
     final_im = Image.fromarray(alpha_blended)
@@ -34,6 +39,7 @@ def overlay_attn(original_image,mask):
 class VITAttentionGradRollout:
     '''
     Expects timm ViT transformer model
+    Adapted from https://github.com/samiraabnar/attention_flow
     '''
     def __init__(self, model, head_fusion='min', discard_ratio=0):
         self.model = model
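
Note on the overlay change: the grayscale blend is replaced by a weighted sum. cv2.addWeighted computes dst = src1*alpha + src2*beta + gamma, so with weights 0.1 and 0.9 the JET-colored attention map now dominates the original image. Below is a self-contained sketch of that blend on synthetic data; the random image, the mask, and the applyColorMap step are illustrative assumptions, since the diff does not show how cmap is produced.

```python
import cv2
import numpy as np

# Illustrative stand-ins, not the Space's data: a fake RGB image and a fake
# attention mask already resized to the image.
original_image = np.random.randint(0, 256, (250, 250, 3), dtype=np.uint8)
mask = np.random.rand(250, 250).astype(np.float32)

# Color the mask with the JET colormap the commit switches to (the exact
# colormapping call in attn_helper.py is not visible in this diff).
cmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)

# dst = 0.1 * original + 0.9 * cmap + 0: the colored attention map dominates.
alpha_blended = cv2.addWeighted(np.uint8(original_image), 0.1, cmap, 0.9, 0)
print(alpha_blended.shape, alpha_blended.dtype)  # (250, 250, 3) uint8
```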