TexR6 commited on
Commit
b6c588d
1 Parent(s): a2933e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -87,6 +87,8 @@ random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RG
87
  def plot_attention(image, layer_num):
88
  """Given an input image, plot the average attention weight given to each image patch by each attention head."""
89
  input_data = data_transforms(image)
 
 
90
  with nopdb.capture_call(vision_transformer.blocks[int(layer_num)-1].attn.forward) as attn_call:
91
  predict_tensor(img_transformed)
92
  attn = attn_call.locals['attn'][0]
@@ -130,7 +132,7 @@ article_attention = """From the dropdown menu, choose the Encoder layer whose at
130
 
131
  classify_interface = gr.Interface(
132
  fn=predict_disease,
133
- inputs=gr.Image(type="pil", label="Image", value="examples/TomatoYellowCurlVirus6.JPG"),
134
  outputs=[gr.Label(num_top_classes=3, label="Predictions"),
135
  gr.Number(label="Prediction time (secs)")],
136
  examples=example_list,
@@ -142,8 +144,9 @@ classify_interface = gr.Interface(
142
 
143
  attention_interface = gr.Interface(
144
  fn=plot_attention,
145
- inputs=[gr.Image(type="pil", label="Image", value="examples/TomatoYellowCurlVirus6.JPG"),
146
- gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"], label="Attention Layer", value="6")],
 
147
  outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
148
  examples=example_list,
149
  title=title_attention,
 
87
  def plot_attention(image, layer_num):
88
  """Given an input image, plot the average attention weight given to each image patch by each attention head."""
89
  input_data = data_transforms(image)
90
+ print(input_data.shape)
91
+ print(layer_num)
92
  with nopdb.capture_call(vision_transformer.blocks[int(layer_num)-1].attn.forward) as attn_call:
93
  predict_tensor(img_transformed)
94
  attn = attn_call.locals['attn'][0]
 
132
 
133
  classify_interface = gr.Interface(
134
  fn=predict_disease,
135
+ inputs=gr.Image(type="pil", label="Image"),
136
  outputs=[gr.Label(num_top_classes=3, label="Predictions"),
137
  gr.Number(label="Prediction time (secs)")],
138
  examples=example_list,
 
144
 
145
  attention_interface = gr.Interface(
146
  fn=plot_attention,
147
+ inputs=[gr.Image(type="pil", label="Image"),
148
+ gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
149
+ value="6", label="Attention Layer")],
150
  outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
151
  examples=example_list,
152
  title=title_attention,