TexR6 commited on
Commit
3f3a5b0
1 Parent(s): fb71b6d

removed redundancies

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -90,16 +90,14 @@ def plot_attention(image, layer_num):
90
  attn = attn_call.locals['attn'][0]
91
  with torch.inference_mode():
92
  # loop over attention heads
93
- attention_block_num = 0
94
  for h_weights in attn:
95
  h_weights = h_weights.mean(axis=-2) # average over all attention keys
96
  h_weights = h_weights[1:] # skip the [class] token
97
- attention_block_num += 1
98
- output_img = plot_weights(input_data, h_weights, attention_block_num)
99
  attention_map_outputs.append(output_img)
100
  return attention_map_outputs
101
 
102
- def plot_weights(input_data, patch_weights, num_attention_block):
103
  """Display the image: Brighter the patch, higher is the attention."""
104
  # multiply each patch of the input image by the corresponding weight
105
  plot = inv_normalize(input_data.clone())
@@ -111,7 +109,7 @@ def plot_weights(input_data, patch_weights, num_attention_block):
111
  attn_map_img = attn_map_img.resize((224, 224), Image.ANTIALIAS)
112
  return attn_map_img
113
 
114
- attention_maps = plot_attention(random_image, 6)
115
 
116
  title_classify = "Image Based Plant Disease Identification 🍃🤓"
117
 
@@ -152,5 +150,7 @@ attention_interface = gr.Interface(
152
  )
153
 
154
  demo = gr.TabbedInterface([classify_interface, attention_interface],
155
- ["Identify Disease", "Visualize Attention Map"],
156
- title="NatureAI Diagnostics🧑🩺").launch(debug=False, share=True)
 
 
 
90
  attn = attn_call.locals['attn'][0]
91
  with torch.inference_mode():
92
  # loop over attention heads
 
93
  for h_weights in attn:
94
  h_weights = h_weights.mean(axis=-2) # average over all attention keys
95
  h_weights = h_weights[1:] # skip the [class] token
96
+ output_img = plot_weights(input_data, h_weights)
 
97
  attention_map_outputs.append(output_img)
98
  return attention_map_outputs
99
 
100
+ def plot_weights(input_data, patch_weights):
101
  """Display the image: Brighter the patch, higher is the attention."""
102
  # multiply each patch of the input image by the corresponding weight
103
  plot = inv_normalize(input_data.clone())
 
109
  attn_map_img = attn_map_img.resize((224, 224), Image.ANTIALIAS)
110
  return attn_map_img
111
 
112
+ attention_maps = plot_attention(random_image, "6")
113
 
114
  title_classify = "Image Based Plant Disease Identification 🍃🤓"
115
 
 
150
  )
151
 
152
  demo = gr.TabbedInterface([classify_interface, attention_interface],
153
+ ["Identify Disease", "Visualize Attention Map"], title="NatureAI Diagnostics🧑🩺")
154
+
155
+ if __name__ == "__main__":
156
+ demo.launch()