fixed attention_layer not receiving input issue.
Browse files
app.py
CHANGED
@@ -83,7 +83,6 @@ random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RG
|
|
83 |
|
84 |
def plot_attention(image, encoder_layer_num=5):
|
85 |
"""Given an input image, plot the average attention weight given to each image patch by each attention head."""
|
86 |
-
print(f'Selected encoder layer num: {encoder_layer_num}')
|
87 |
attention_map_outputs = []
|
88 |
input_data = data_transforms(image)
|
89 |
with nopdb.capture_call(vision_transformer.blocks[encoder_layer_num].attn.forward) as attn_call:
|
@@ -140,7 +139,8 @@ classify_interface = gr.Interface(
|
|
140 |
attention_interface = gr.Interface(
|
141 |
fn=plot_attention,
|
142 |
inputs=[gr.Image(type="pil", label="Image"),
|
143 |
-
gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
|
|
|
144 |
outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
|
145 |
examples=example_list,
|
146 |
title=title_attention,
|
|
|
83 |
|
84 |
def plot_attention(image, encoder_layer_num=5):
|
85 |
"""Given an input image, plot the average attention weight given to each image patch by each attention head."""
|
|
|
86 |
attention_map_outputs = []
|
87 |
input_data = data_transforms(image)
|
88 |
with nopdb.capture_call(vision_transformer.blocks[encoder_layer_num].attn.forward) as attn_call:
|
|
|
139 |
attention_interface = gr.Interface(
|
140 |
fn=plot_attention,
|
141 |
inputs=[gr.Image(type="pil", label="Image"),
|
142 |
+
gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
|
143 |
+
label="Attention Layer", value="6", type="index")],
|
144 |
outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
|
145 |
examples=example_list,
|
146 |
title=title_attention,
|