TexR6 committed on
Commit
b72342a
1 Parent(s): 0c3fb24

Fixed an issue where the attention layer was not receiving input.

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -83,7 +83,6 @@ random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RG
83
 
84
  def plot_attention(image, encoder_layer_num=5):
85
  """Given an input image, plot the average attention weight given to each image patch by each attention head."""
86
- print(f'Selected encoder layer num: {encoder_layer_num}')
87
  attention_map_outputs = []
88
  input_data = data_transforms(image)
89
  with nopdb.capture_call(vision_transformer.blocks[encoder_layer_num].attn.forward) as attn_call:
@@ -140,7 +139,8 @@ classify_interface = gr.Interface(
140
  attention_interface = gr.Interface(
141
  fn=plot_attention,
142
  inputs=[gr.Image(type="pil", label="Image"),
143
- gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"], value="6", type="index")],
 
144
  outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
145
  examples=example_list,
146
  title=title_attention,
 
83
 
84
  def plot_attention(image, encoder_layer_num=5):
85
  """Given an input image, plot the average attention weight given to each image patch by each attention head."""
 
86
  attention_map_outputs = []
87
  input_data = data_transforms(image)
88
  with nopdb.capture_call(vision_transformer.blocks[encoder_layer_num].attn.forward) as attn_call:
 
139
  attention_interface = gr.Interface(
140
  fn=plot_attention,
141
  inputs=[gr.Image(type="pil", label="Image"),
142
+ gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
143
+ label="Attention Layer", value="6", type="index")],
144
  outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
145
  examples=example_list,
146
  title=title_attention,