TexR6 committed
Commit 9efe5cc
Parent: 40d3684

removed encoder layer selection feature (for now)

Files changed (1): app.py (+3 −4)
app.py CHANGED
@@ -81,11 +81,11 @@ def predict_tensor(img_tensor):
 
 random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
 
-def plot_attention(image, layer_num):
+def plot_attention(image):
     """Given an input image, plot the average attention weight given to each image patch by each attention head."""
     attention_map_outputs = []
     input_data = data_transforms(image)
-    with nopdb.capture_call(vision_transformer.blocks[int(layer_num)-1].attn.forward) as attn_call:
+    with nopdb.capture_call(vision_transformer.blocks[int(6)-1].attn.forward) as attn_call:
         predict_tensor(img_transformed)
     attn = attn_call.locals['attn'][0]
     with torch.inference_mode():
@@ -138,8 +138,7 @@ classify_interface = gr.Interface(
 
 attention_interface = gr.Interface(
     fn=plot_attention,
-    inputs=[gr.Image(type="pil", label="Image"),
-            gr.Slider(1, 12, step=1, label="Transformer Encoder Layer")],
+    inputs=gr.Image(type="pil", label="Image"),
     outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
     examples=example_list,
     title=title_attention,
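
For context on the first hunk: nopdb.capture_call wraps a function so that, after it runs, its local variables can be read back, which is how plot_attention reads the attn tensor computed inside the block's Attention.forward. Below is a minimal runnable sketch of that same pattern, assuming a timm ViT; the model name, input shape, and the fused_attn guard are illustrative, not taken from this commit.

import timm
import torch
import nopdb

# Any timm ViT works; ViT-B/16 matches the 12-block indexing in the diff.
vit = timm.create_model('vit_base_patch16_224', pretrained=False).eval()

# Recent timm versions use fused attention, which skips the eager `attn`
# local; fall back to the eager path so nopdb has something to capture.
if hasattr(vit.blocks[5].attn, 'fused_attn'):
    vit.blocks[5].attn.fused_attn = False

x = torch.randn(1, 3, 224, 224)  # stand-in for the transformed input image

# Mirror vision_transformer.blocks[int(6)-1].attn.forward from the diff:
# capture the forward call of block 6's attention module (index 5).
with nopdb.capture_call(vit.blocks[5].attn.forward) as attn_call:
    with torch.inference_mode():
        vit(x)

attn = attn_call.locals['attn'][0]  # drop batch dim -> [heads, tokens, tokens]
print(attn.shape)                   # torch.Size([12, 197, 197]) for ViT-B/16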
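
The second hunk follows from how gr.Interface binds components to the function signature: each input component supplies one positional argument to fn, so dropping the slider also requires dropping layer_num from plot_attention. A hedged before/after sketch of that binding; the function bodies are placeholders, not the app's code.

import gradio as gr

# Before the commit: two input components -> fn takes two arguments.
def plot_attention_v1(image, layer_num):
    return f"would plot attention for layer {int(layer_num)}"

before = gr.Interface(
    fn=plot_attention_v1,
    inputs=[gr.Image(type="pil", label="Image"),
            gr.Slider(1, 12, step=1, label="Transformer Encoder Layer")],
    outputs=gr.Textbox(),
)

# After the commit: a single component -> fn takes a single argument,
# with the layer hardcoded inside (layer 6 in this commit).
def plot_attention_v2(image):
    return "would plot attention for layer 6"

after = gr.Interface(
    fn=plot_attention_v2,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=gr.Textbox(),
)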