sneha
commited on
Commit
•
5161efd
1
Parent(s):
ba48df9
fix repo_id
Browse files
app.py
CHANGED
@@ -18,7 +18,7 @@ MODEL_DIR=os.path.join(os.path.dirname(eai_filepath),'model_ckpts')
|
|
18 |
if not os.path.isdir(MODEL_DIR):
|
19 |
os.mkdir(MODEL_DIR)
|
20 |
|
21 |
-
|
22 |
FILENAME = "config.yaml"
|
23 |
BASE_MODEL_TUPLE = None
|
24 |
LARGE_MODEL_TUPLE = None
|
@@ -61,10 +61,10 @@ def download_bin(model):
|
|
61 |
bin_file = 'vc1_vitb.pth'
|
62 |
else:
|
63 |
raise NameError("model not found: " + model)
|
64 |
-
|
65 |
bin_path = os.path.join(MODEL_DIR,bin_file)
|
66 |
if not os.path.isfile(bin_path):
|
67 |
-
model_bin = hf_hub_download(repo_id=
|
68 |
os.rename(model_bin, bin_path)
|
69 |
|
70 |
|
@@ -101,10 +101,12 @@ input_button = gr.Radio(["min", "max", "mean"], value="min",label="Attention Hea
|
|
101 |
output_img = gr.Image(shape=(250,250))
|
102 |
output_plot = gr.Plot()
|
103 |
|
|
|
|
|
104 |
markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention of the last layer of the transformer.\n \
|
105 |
The user can decide how the attention heads will be combined. \
|
106 |
Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 or 16x64 grid."
|
107 |
demo = gr.Interface(fn=run_attn, title="Visual Cortex Base Model", description=markdown,
|
108 |
examples=[[os.path.join('./imgs',x),None,None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
|
109 |
-
inputs=[input_img,model_type,input_button],outputs=[output_img,output_plot])
|
110 |
-
demo.launch()
|
|
|
18 |
if not os.path.isdir(MODEL_DIR):
|
19 |
os.mkdir(MODEL_DIR)
|
20 |
|
21 |
+
|
22 |
FILENAME = "config.yaml"
|
23 |
BASE_MODEL_TUPLE = None
|
24 |
LARGE_MODEL_TUPLE = None
|
|
|
61 |
bin_file = 'vc1_vitb.pth'
|
62 |
else:
|
63 |
raise NameError("model not found: " + model)
|
64 |
+
repo_name = 'facebook/' + model
|
65 |
bin_path = os.path.join(MODEL_DIR,bin_file)
|
66 |
if not os.path.isfile(bin_path):
|
67 |
+
model_bin = hf_hub_download(repo_id=repo_name, filename='pytorch_model.bin',local_dir=MODEL_DIR,local_dir_use_symlinks=True,token=HF_TOKEN)
|
68 |
os.rename(model_bin, bin_path)
|
69 |
|
70 |
|
|
|
101 |
output_img = gr.Image(shape=(250,250))
|
102 |
output_plot = gr.Plot()
|
103 |
|
104 |
+
css = ".output-image, .input-image, .image-preview {height: 600px !important}"
|
105 |
+
|
106 |
markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention of the last layer of the transformer.\n \
|
107 |
The user can decide how the attention heads will be combined. \
|
108 |
Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 or 16x64 grid."
|
109 |
demo = gr.Interface(fn=run_attn, title="Visual Cortex Base Model", description=markdown,
|
110 |
examples=[[os.path.join('./imgs',x),None,None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
|
111 |
+
inputs=[input_img,model_type,input_button],outputs=[output_img,output_plot],css=css)
|
112 |
+
demo.launch(share=True)
|