fixed file saving issues
Browse files
app.py
CHANGED
@@ -79,16 +79,12 @@ def predict_tensor(img_tensor):
|
|
79 |
probs = torch.softmax(logits, dim=0)
|
80 |
topk_prob, topk_id = torch.topk(probs, 3)
|
81 |
|
82 |
-
with nopdb.capture_call(vision_transformer.blocks[5].attn.forward) as attn_call:
|
83 |
-
predict_tensor(img_transformed)
|
84 |
-
|
85 |
random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
|
86 |
|
87 |
def plot_attention(image, layer_num):
|
88 |
"""Given an input image, plot the average attention weight given to each image patch by each attention head."""
|
|
|
89 |
input_data = data_transforms(image)
|
90 |
-
print(input_data.shape)
|
91 |
-
print(layer_num)
|
92 |
with nopdb.capture_call(vision_transformer.blocks[int(layer_num)-1].attn.forward) as attn_call:
|
93 |
predict_tensor(img_transformed)
|
94 |
attn = attn_call.locals['attn'][0]
|
@@ -99,9 +95,9 @@ def plot_attention(image, layer_num):
|
|
99 |
h_weights = h_weights.mean(axis=-2) # average over all attention keys
|
100 |
h_weights = h_weights[1:] # skip the [class] token
|
101 |
attention_block_num += 1
|
102 |
-
plot_weights(input_data, h_weights, attention_block_num)
|
103 |
-
|
104 |
-
return
|
105 |
|
106 |
def plot_weights(input_data, patch_weights, num_attention_block):
|
107 |
"""Display the image: Brighter the patch, higher is the attention."""
|
@@ -113,10 +109,9 @@ def plot_weights(input_data, patch_weights, num_attention_block):
|
|
113 |
plot[:, y:y + 16, x:x + 16] *= patch_weights[i]
|
114 |
attn_map_img = inv_transform(plot, normalize=False)
|
115 |
attn_map_img = attn_map_img.resize((224, 224), Image.ANTIALIAS)
|
116 |
-
attn_map_img
|
117 |
|
118 |
-
|
119 |
-
attention_maps = plot_attention(random_image, DEFAULT_LAYER)
|
120 |
|
121 |
title_classify = "Image Based Plant Disease Identification ππ€"
|
122 |
|
@@ -147,7 +142,7 @@ attention_interface = gr.Interface(
|
|
147 |
fn=plot_attention,
|
148 |
inputs=[gr.Image(type="pil", label="Image"),
|
149 |
gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
|
150 |
-
|
151 |
outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
|
152 |
examples=example_list,
|
153 |
title=title_attention,
|
|
|
79 |
probs = torch.softmax(logits, dim=0)
|
80 |
topk_prob, topk_id = torch.topk(probs, 3)
|
81 |
|
|
|
|
|
|
|
82 |
random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
|
83 |
|
84 |
def plot_attention(image, layer_num):
|
85 |
"""Given an input image, plot the average attention weight given to each image patch by each attention head."""
|
86 |
+
attention_map_outputs = []
|
87 |
input_data = data_transforms(image)
|
|
|
|
|
88 |
with nopdb.capture_call(vision_transformer.blocks[int(layer_num)-1].attn.forward) as attn_call:
|
89 |
predict_tensor(img_transformed)
|
90 |
attn = attn_call.locals['attn'][0]
|
|
|
95 |
h_weights = h_weights.mean(axis=-2) # average over all attention keys
|
96 |
h_weights = h_weights[1:] # skip the [class] token
|
97 |
attention_block_num += 1
|
98 |
+
output_img = plot_weights(input_data, h_weights, attention_block_num)
|
99 |
+
attention_map_outputs.append(output_img)
|
100 |
+
return attention_map_outputs
|
101 |
|
102 |
def plot_weights(input_data, patch_weights, num_attention_block):
|
103 |
"""Display the image: Brighter the patch, higher is the attention."""
|
|
|
109 |
plot[:, y:y + 16, x:x + 16] *= patch_weights[i]
|
110 |
attn_map_img = inv_transform(plot, normalize=False)
|
111 |
attn_map_img = attn_map_img.resize((224, 224), Image.ANTIALIAS)
|
112 |
+
return attn_map_img
|
113 |
|
114 |
+
attention_maps = plot_attention(random_image, 6)
|
|
|
115 |
|
116 |
title_classify = "Image Based Plant Disease Identification ππ€"
|
117 |
|
|
|
142 |
fn=plot_attention,
|
143 |
inputs=[gr.Image(type="pil", label="Image"),
|
144 |
gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
|
145 |
+
label="Attention Layer", value="6")],
|
146 |
outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
|
147 |
examples=example_list,
|
148 |
title=title_attention,
|