sneha committed
Commit 5161efd
1 Parent(s): ba48df9

fix repo_id

Files changed (1): app.py (+7 -5)
app.py CHANGED
@@ -18,7 +18,7 @@ MODEL_DIR=os.path.join(os.path.dirname(eai_filepath),'model_ckpts')
 if not os.path.isdir(MODEL_DIR):
     os.mkdir(MODEL_DIR)
 
-REPO_ID = "facebook/vc1-base"
+
 FILENAME = "config.yaml"
 BASE_MODEL_TUPLE = None
 LARGE_MODEL_TUPLE = None
@@ -61,10 +61,10 @@ def download_bin(model):
         bin_file = 'vc1_vitb.pth'
     else:
         raise NameError("model not found: " + model)
-
+    repo_name = 'facebook/' + model
     bin_path = os.path.join(MODEL_DIR,bin_file)
     if not os.path.isfile(bin_path):
-        model_bin = hf_hub_download(repo_id=REPO_ID, filename='pytorch_model.bin',local_dir=MODEL_DIR,local_dir_use_symlinks=True,token=HF_TOKEN)
+        model_bin = hf_hub_download(repo_id=repo_name, filename='pytorch_model.bin',local_dir=MODEL_DIR,local_dir_use_symlinks=True,token=HF_TOKEN)
         os.rename(model_bin, bin_path)
 
 
@@ -101,10 +101,12 @@ input_button = gr.Radio(["min", "max", "mean"], value="min",label="Attention Hea
 output_img = gr.Image(shape=(250,250))
 output_plot = gr.Plot()
 
+css = ".output-image, .input-image, .image-preview {height: 600px !important}"
+
 markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention of the last layer of the transformer.\n \
 The user can decide how the attention heads will be combined. \
 Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 or 16x64 grid."
 demo = gr.Interface(fn=run_attn, title="Visual Cortex Base Model", description=markdown,
                     examples=[[os.path.join('./imgs',x),None,None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
-                    inputs=[input_img,model_type,input_button],outputs=[output_img,output_plot])
-demo.launch()
+                    inputs=[input_img,model_type,input_button],outputs=[output_img,output_plot],css=css)
+demo.launch(share=True)
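
For context: the fix replaces the hardcoded REPO_ID (which pointed every download at facebook/vc1-base) with a repo name derived from the selected model, so each checkpoint is fetched from its own Hub repo. Below is a minimal, self-contained sketch of the download path after this commit; the model identifiers "vc1-base"/"vc1-large", the large-model filename, and the HF_TOKEN environment variable are assumptions based on the surrounding app.py, not shown in this diff.

# Sketch of download_bin after this commit (assumptions noted inline).
import os
from huggingface_hub import hf_hub_download

MODEL_DIR = "model_ckpts"              # checkpoint cache dir, as in app.py
HF_TOKEN = os.environ.get("HF_TOKEN")  # token for gated facebook/vc1-* repos

def download_bin(model):
    # Map the selected model to its local checkpoint filename.
    if model == "vc1-large":           # assumed identifier; only the
        bin_file = "vc1_vitl.pth"      # vc1-base branch appears in the diff
    elif model == "vc1-base":
        bin_file = "vc1_vitb.pth"
    else:
        raise NameError("model not found: " + model)
    repo_name = "facebook/" + model    # e.g. "facebook/vc1-base"
    bin_path = os.path.join(MODEL_DIR, bin_file)
    if not os.path.isfile(bin_path):
        # Download once, then rename to the model-specific filename so the
        # two checkpoints don't overwrite each other's pytorch_model.bin.
        model_bin = hf_hub_download(
            repo_id=repo_name,
            filename="pytorch_model.bin",
            local_dir=MODEL_DIR,
            token=HF_TOKEN,
        )
        os.rename(model_bin, bin_path)
    return bin_path

With this change, "vc1-large" resolves to facebook/vc1-large instead of the base repo the old REPO_ID hardcoded, which appears to be the bug the commit message "fix repo_id" refers to.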