TexR6 committed on
Commit
fb71b6d
β€’
1 Parent(s): 65c82bc

fixed file saving issues

Browse files
Files changed (1) hide show
  1. app.py +7 -12
app.py CHANGED
@@ -79,16 +79,12 @@ def predict_tensor(img_tensor):
79
  probs = torch.softmax(logits, dim=0)
80
  topk_prob, topk_id = torch.topk(probs, 3)
81
 
82
- with nopdb.capture_call(vision_transformer.blocks[5].attn.forward) as attn_call:
83
- predict_tensor(img_transformed)
84
-
85
  random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
86
 
87
  def plot_attention(image, layer_num):
88
  """Given an input image, plot the average attention weight given to each image patch by each attention head."""
 
89
  input_data = data_transforms(image)
90
- print(input_data.shape)
91
- print(layer_num)
92
  with nopdb.capture_call(vision_transformer.blocks[int(layer_num)-1].attn.forward) as attn_call:
93
  predict_tensor(img_transformed)
94
  attn = attn_call.locals['attn'][0]
@@ -99,9 +95,9 @@ def plot_attention(image, layer_num):
99
  h_weights = h_weights.mean(axis=-2) # average over all attention keys
100
  h_weights = h_weights[1:] # skip the [class] token
101
  attention_block_num += 1
102
- plot_weights(input_data, h_weights, attention_block_num)
103
- attention_maps = glob.glob('storage/*.png')
104
- return attention_maps
105
 
106
  def plot_weights(input_data, patch_weights, num_attention_block):
107
  """Display the image: Brighter the patch, higher is the attention."""
@@ -113,10 +109,9 @@ def plot_weights(input_data, patch_weights, num_attention_block):
113
  plot[:, y:y + 16, x:x + 16] *= patch_weights[i]
114
  attn_map_img = inv_transform(plot, normalize=False)
115
  attn_map_img = attn_map_img.resize((224, 224), Image.ANTIALIAS)
116
- attn_map_img.save(f"storage/attention_map_{num_attention_block}.png", "PNG")
117
 
118
- DEFAULT_LAYER = "6"
119
- attention_maps = plot_attention(random_image, DEFAULT_LAYER)
120
 
121
  title_classify = "Image Based Plant Disease Identification πŸƒπŸ€“"
122
 
@@ -147,7 +142,7 @@ attention_interface = gr.Interface(
147
  fn=plot_attention,
148
  inputs=[gr.Image(type="pil", label="Image"),
149
  gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
150
- value=DEFAULT_LAYER, label="Attention Layer")],
151
  outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
152
  examples=example_list,
153
  title=title_attention,
 
79
  probs = torch.softmax(logits, dim=0)
80
  topk_prob, topk_id = torch.topk(probs, 3)
81
 
 
 
 
82
  random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
83
 
84
  def plot_attention(image, layer_num):
85
  """Given an input image, plot the average attention weight given to each image patch by each attention head."""
86
+ attention_map_outputs = []
87
  input_data = data_transforms(image)
 
 
88
  with nopdb.capture_call(vision_transformer.blocks[int(layer_num)-1].attn.forward) as attn_call:
89
  predict_tensor(img_transformed)
90
  attn = attn_call.locals['attn'][0]
 
95
  h_weights = h_weights.mean(axis=-2) # average over all attention keys
96
  h_weights = h_weights[1:] # skip the [class] token
97
  attention_block_num += 1
98
+ output_img = plot_weights(input_data, h_weights, attention_block_num)
99
+ attention_map_outputs.append(output_img)
100
+ return attention_map_outputs
101
 
102
  def plot_weights(input_data, patch_weights, num_attention_block):
103
  """Display the image: Brighter the patch, higher is the attention."""
 
109
  plot[:, y:y + 16, x:x + 16] *= patch_weights[i]
110
  attn_map_img = inv_transform(plot, normalize=False)
111
  attn_map_img = attn_map_img.resize((224, 224), Image.ANTIALIAS)
112
+ return attn_map_img
113
 
114
+ attention_maps = plot_attention(random_image, 6)
 
115
 
116
  title_classify = "Image Based Plant Disease Identification πŸƒπŸ€“"
117
 
 
142
  fn=plot_attention,
143
  inputs=[gr.Image(type="pil", label="Image"),
144
  gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
145
+ label="Attention Layer", value="6")],
146
  outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
147
  examples=example_list,
148
  title=title_attention,