sneha committed on
Commit
30ae246
1 Parent(s): 46f48ca
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -63,7 +63,7 @@ def download_bin(model):
63
  os.rename(model_bin, bin_path)
64
 
65
 
66
- def run_attn(input_img, model="vc1-base"):
67
  download_bin(model)
68
  model, embedding_dim, transform, metadata = get_model(model)
69
  if input_img.shape[0] != 3:
@@ -75,7 +75,7 @@ def run_attn(input_img, model="vc1-base"):
75
  input_img = resize_transform(input_img)
76
  x = transform(input_img)
77
 
78
- attention_rollout = VITAttentionGradRollout(model,head_fusion="max",discard_ratio=0.89)
79
 
80
  y = model(x)
81
  mask = attention_rollout.get_attn_mask()
@@ -85,10 +85,12 @@ def run_attn(input_img, model="vc1-base"):
85
  model_type = gr.Dropdown(
86
  ["vc1-base", "vc1-large"], label="Model Size", value="vc1-base")
87
  input_img = gr.Image(shape=(250,250))
 
 
88
  output_img = gr.Image(shape=(250,250))
89
  css = "#component-2, .input-image, .image-preview {height: 240px !important}"
90
  markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention(green) of the last layer of the transformer."
91
  demo = gr.Interface(fn=run_attn, title="Visual Cortex Model", description=markdown,
92
- examples=[[os.path.join('./imgs',x),None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
93
- inputs=[input_img,model_type],outputs=output_img,css=css)
94
  demo.launch()
63
  os.rename(model_bin, bin_path)
64
 
65
 
66
+ def run_attn(input_img, model="vc1-base",discard_ratio=0.89):
67
  download_bin(model)
68
  model, embedding_dim, transform, metadata = get_model(model)
69
  if input_img.shape[0] != 3:
75
  input_img = resize_transform(input_img)
76
  x = transform(input_img)
77
 
78
+ attention_rollout = VITAttentionGradRollout(model,head_fusion="max",discard_ratio=discard_ratio)
79
 
80
  y = model(x)
81
  mask = attention_rollout.get_attn_mask()
85
  model_type = gr.Dropdown(
86
  ["vc1-base", "vc1-large"], label="Model Size", value="vc1-base")
87
  input_img = gr.Image(shape=(250,250))
88
+ discard_ratio = gr.Slider(0,1,value=0.89)
89
+
90
  output_img = gr.Image(shape=(250,250))
91
  css = "#component-2, .input-image, .image-preview {height: 240px !important}"
92
  markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention(green) of the last layer of the transformer."
93
  demo = gr.Interface(fn=run_attn, title="Visual Cortex Model", description=markdown,
94
+ examples=[[os.path.join('./imgs',x),None,None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
95
+ inputs=[input_img,model_type,discard_ratio],outputs=output_img,css=css)
96
  demo.launch()