TexR6 committed on
Commit
0c3fb24
1 Parent(s): b694259

Fix attention-layer selection: give plot_attention a default layer, log the selected layer, and set the Dropdown default to a valid choice ("6")

Files changed (1):
  1. app.py +3 -2
app.py CHANGED
@@ -81,8 +81,9 @@ def predict_tensor(img_tensor):
 
 random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
 
-def plot_attention(image, encoder_layer_num):
+def plot_attention(image, encoder_layer_num=5):
     """Given an input image, plot the average attention weight given to each image patch by each attention head."""
+    print(f'Selected encoder layer num: {encoder_layer_num}')
     attention_map_outputs = []
     input_data = data_transforms(image)
     with nopdb.capture_call(vision_transformer.blocks[encoder_layer_num].attn.forward) as attn_call:
@@ -139,7 +140,7 @@ classify_interface = gr.Interface(
 attention_interface = gr.Interface(
     fn=plot_attention,
     inputs=[gr.Image(type="pil", label="Image"),
-            gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"], value=5, type="index")],
+            gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"], value="6", type="index")],
     outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
     examples=example_list,
     title=title_attention,
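
Note on the Dropdown change: with type="index", a Gradio Dropdown passes the zero-based position of the selected choice to the wrapped function, and the default value must be one of the choice strings, so value="6" arrives in plot_attention as the integer 5, matching the new encoder_layer_num=5 default. A minimal sketch of that behaviour, assuming Gradio 3.x; the show_layer function and demo interface below are hypothetical stand-ins, not part of app.py:

import gradio as gr

# Sketch only: with type="index", the Dropdown hands the function the
# zero-based position of the chosen string, so the default value="6"
# is received as the integer 5.
def show_layer(encoder_layer_num=5):
    return f'Selected encoder layer num: {encoder_layer_num}'

demo = gr.Interface(
    fn=show_layer,
    inputs=gr.Dropdown(choices=[str(i) for i in range(1, 13)], value="6", type="index"),
    outputs="text",
)

if __name__ == '__main__':
    demo.launch()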