fuck this shit!
Browse files
app.py
CHANGED
@@ -81,8 +81,9 @@ def predict_tensor(img_tensor):
|
|
81 |
|
82 |
random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
|
83 |
|
84 |
-
def plot_attention(image, encoder_layer_num):
|
85 |
"""Given an input image, plot the average attention weight given to each image patch by each attention head."""
|
|
|
86 |
attention_map_outputs = []
|
87 |
input_data = data_transforms(image)
|
88 |
with nopdb.capture_call(vision_transformer.blocks[encoder_layer_num].attn.forward) as attn_call:
|
@@ -139,7 +140,7 @@ classify_interface = gr.Interface(
|
|
139 |
attention_interface = gr.Interface(
|
140 |
fn=plot_attention,
|
141 |
inputs=[gr.Image(type="pil", label="Image"),
|
142 |
-
gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"], value=
|
143 |
outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
|
144 |
examples=example_list,
|
145 |
title=title_attention,
|
|
|
81 |
|
82 |
random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
|
83 |
|
84 |
+
def plot_attention(image, encoder_layer_num=5):
|
85 |
"""Given an input image, plot the average attention weight given to each image patch by each attention head."""
|
86 |
+
print(f'Selected encoder layer num: {encoder_layer_num}')
|
87 |
attention_map_outputs = []
|
88 |
input_data = data_transforms(image)
|
89 |
with nopdb.capture_call(vision_transformer.blocks[encoder_layer_num].attn.forward) as attn_call:
|
|
|
140 |
attention_interface = gr.Interface(
|
141 |
fn=plot_attention,
|
142 |
inputs=[gr.Image(type="pil", label="Image"),
|
143 |
+
gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"], value="6", type="index")],
|
144 |
outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
|
145 |
examples=example_list,
|
146 |
title=title_attention,
|