removed redundancies
Browse files
app.py
CHANGED
@@ -90,16 +90,14 @@ def plot_attention(image, layer_num):
|
|
90 |
attn = attn_call.locals['attn'][0]
|
91 |
with torch.inference_mode():
|
92 |
# loop over attention heads
|
93 |
-
attention_block_num = 0
|
94 |
for h_weights in attn:
|
95 |
h_weights = h_weights.mean(axis=-2) # average over all attention keys
|
96 |
h_weights = h_weights[1:] # skip the [class] token
|
97 |
-
|
98 |
-
output_img = plot_weights(input_data, h_weights, attention_block_num)
|
99 |
attention_map_outputs.append(output_img)
|
100 |
return attention_map_outputs
|
101 |
|
102 |
-
def plot_weights(input_data, patch_weights
|
103 |
"""Display the image: Brighter the patch, higher is the attention."""
|
104 |
# multiply each patch of the input image by the corresponding weight
|
105 |
plot = inv_normalize(input_data.clone())
|
@@ -111,7 +109,7 @@ def plot_weights(input_data, patch_weights, num_attention_block):
|
|
111 |
attn_map_img = attn_map_img.resize((224, 224), Image.ANTIALIAS)
|
112 |
return attn_map_img
|
113 |
|
114 |
-
attention_maps = plot_attention(random_image, 6)
|
115 |
|
116 |
title_classify = "Image Based Plant Disease Identification 🍃🤓"
|
117 |
|
@@ -152,5 +150,7 @@ attention_interface = gr.Interface(
|
|
152 |
)
|
153 |
|
154 |
demo = gr.TabbedInterface([classify_interface, attention_interface],
|
155 |
-
["Identify Disease", "Visualize Attention Map"],
|
156 |
-
|
|
|
|
|
|
90 |
attn = attn_call.locals['attn'][0]
|
91 |
with torch.inference_mode():
|
92 |
# loop over attention heads
|
|
|
93 |
for h_weights in attn:
|
94 |
h_weights = h_weights.mean(axis=-2) # average over all attention keys
|
95 |
h_weights = h_weights[1:] # skip the [class] token
|
96 |
+
output_img = plot_weights(input_data, h_weights)
|
|
|
97 |
attention_map_outputs.append(output_img)
|
98 |
return attention_map_outputs
|
99 |
|
100 |
+
def plot_weights(input_data, patch_weights):
|
101 |
"""Display the image: Brighter the patch, higher is the attention."""
|
102 |
# multiply each patch of the input image by the corresponding weight
|
103 |
plot = inv_normalize(input_data.clone())
|
|
|
109 |
attn_map_img = attn_map_img.resize((224, 224), Image.ANTIALIAS)
|
110 |
return attn_map_img
|
111 |
|
112 |
+
attention_maps = plot_attention(random_image, "6")
|
113 |
|
114 |
title_classify = "Image Based Plant Disease Identification 🍃🤓"
|
115 |
|
|
|
150 |
)
|
151 |
|
152 |
demo = gr.TabbedInterface([classify_interface, attention_interface],
|
153 |
+
["Identify Disease", "Visualize Attention Map"], title="NatureAI Diagnostics🧑🩺")
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
demo.launch()
|