Update app.py
Browse files
app.py
CHANGED
@@ -87,6 +87,8 @@ random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RG
|
|
87 |
def plot_attention(image, layer_num):
|
88 |
"""Given an input image, plot the average attention weight given to each image patch by each attention head."""
|
89 |
input_data = data_transforms(image)
|
|
|
|
|
90 |
with nopdb.capture_call(vision_transformer.blocks[int(layer_num)-1].attn.forward) as attn_call:
|
91 |
predict_tensor(img_transformed)
|
92 |
attn = attn_call.locals['attn'][0]
|
@@ -130,7 +132,7 @@ article_attention = """From the dropdown menu, choose the Encoder layer whose at
|
|
130 |
|
131 |
classify_interface = gr.Interface(
|
132 |
fn=predict_disease,
|
133 |
-
inputs=gr.Image(type="pil", label="Image"
|
134 |
outputs=[gr.Label(num_top_classes=3, label="Predictions"),
|
135 |
gr.Number(label="Prediction time (secs)")],
|
136 |
examples=example_list,
|
@@ -142,8 +144,9 @@ classify_interface = gr.Interface(
|
|
142 |
|
143 |
attention_interface = gr.Interface(
|
144 |
fn=plot_attention,
|
145 |
-
inputs=[gr.Image(type="pil", label="Image"
|
146 |
-
gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
|
|
|
147 |
outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
|
148 |
examples=example_list,
|
149 |
title=title_attention,
|
|
|
87 |
def plot_attention(image, layer_num):
|
88 |
"""Given an input image, plot the average attention weight given to each image patch by each attention head."""
|
89 |
input_data = data_transforms(image)
|
90 |
+
print(input_data.shape)
|
91 |
+
print(layer_num)
|
92 |
with nopdb.capture_call(vision_transformer.blocks[int(layer_num)-1].attn.forward) as attn_call:
|
93 |
predict_tensor(img_transformed)
|
94 |
attn = attn_call.locals['attn'][0]
|
|
|
132 |
|
133 |
classify_interface = gr.Interface(
|
134 |
fn=predict_disease,
|
135 |
+
inputs=gr.Image(type="pil", label="Image"),
|
136 |
outputs=[gr.Label(num_top_classes=3, label="Predictions"),
|
137 |
gr.Number(label="Prediction time (secs)")],
|
138 |
examples=example_list,
|
|
|
144 |
|
145 |
attention_interface = gr.Interface(
|
146 |
fn=plot_attention,
|
147 |
+
inputs=[gr.Image(type="pil", label="Image"),
|
148 |
+
gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
|
149 |
+
value="6", label="Attention Layer")],
|
150 |
outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
|
151 |
examples=example_list,
|
152 |
title=title_attention,
|