sneha committed on
Commit c54235c
1 Parent(s): 1df99f6

allow switching between models

Files changed (1)
  1. app.py +44 -20
app.py CHANGED
@@ -20,32 +20,53 @@ if not os.path.isdir(MODEL_DIR):
 
 REPO_ID = "facebook/vc1-base"
 FILENAME = "config.yaml"
-MODEL_TUPLE = None
-
-def get_model():
-    global MODEL_TUPLE
-    download_bin()
-    if MODEL_TUPLE is None:
+BASE_MODEL_TUPLE = None
+LARGE_MODEL_TUPLE = None
+def get_model(model_name):
+    global BASE_MODEL_TUPLE,LARGE_MODEL_TUPLE
+    download_bin(model_name)
+    model = None
+    if BASE_MODEL_TUPLE is None and model_name == 'vc1-base':
+        repo_name = "facebook/" + model_name
+        model_cfg = omegaconf.OmegaConf.load(
+            hf_hub_download(repo_id=repo_name, filename=FILENAME,token=HF_TOKEN)
+        )
+        # model_cfg['model']['checkpoint_path'] = None
+        # model_cfg['model']['checkpoint_path'] = 'model_ckpts/vc1_vitb.pth'
+        BASE_MODEL_TUPLE = utils.instantiate(model_cfg)
+        BASE_MODEL_TUPLE[0].eval()
+        model = BASE_MODEL_TUPLE
+    elif LARGE_MODEL_TUPLE is None and model_name == 'vc1-large':
+        repo_name = "facebook/" + model_name
         model_cfg = omegaconf.OmegaConf.load(
-            hf_hub_download(repo_id=REPO_ID, filename=FILENAME,token=HF_TOKEN)
+            hf_hub_download(repo_id=repo_name, filename=FILENAME,token=HF_TOKEN)
         )
-        model_cfg['model']['checkpoint_path'] = None
-        model_cfg['model']['checkpoint_path'] = 'model_ckpts/vc1_vitb.pth'
-        MODEL_TUPLE = utils.instantiate(model_cfg)
-        MODEL_TUPLE[0].eval()
-    return MODEL_TUPLE#model,embedding_dim,transform,metadata
+        # model_cfg['model']['checkpoint_path'] = None
+        # model_cfg['model']['checkpoint_path'] = 'model_ckpts/vc1_vitb.pth'
+        LARGE_MODEL_TUPLE = utils.instantiate(model_cfg)
+        LARGE_MODEL_TUPLE[0].eval()
+        model = LARGE_MODEL_TUPLE
+    elif model_name == 'vc1-base':
+        model = BASE_MODEL_TUPLE
+    elif model_name == 'vc1-large':
+        model = LARGE_MODEL_TUPLE
+
+    return model #model,embedding_dim,transform,metadata
 
-def download_bin():
-    bin_file = 'vc1_vitb.pth'
+def download_bin(model):
+    if model == "vc1-large":
+        bin_file = 'vc1_vitl.pth'
+    elif model == "vc1-base":
+        bin_file = 'vc1_vitb.pth'
     bin_path = os.path.join(MODEL_DIR,bin_file)
     if not os.path.isfile(bin_path):
         model_bin = hf_hub_download(repo_id=REPO_ID, filename='pytorch_model.bin',local_dir=MODEL_DIR,local_dir_use_symlinks=True,token=HF_TOKEN)
         os.rename(model_bin, bin_path)
 
 
-def run_attn(input_img,fusion="min"):
-    download_bin()
-    model, embedding_dim, transform, metadata = get_model()
+def run_attn(model, input_img,fusion="min"):
+    download_bin(model)
+    model, embedding_dim, transform, metadata = get_model(model)
     if input_img.shape[0] != 3:
         input_img = input_img.transpose(2, 0, 1)
     if(len(input_img.shape)== 3):
@@ -63,11 +84,14 @@ def run_attn(input_img,fusion="min"):
 
     fig = plt.figure()
     ax = fig.subplots()
+    print(y.shape)
    im = ax.matshow(y.detach().numpy().reshape(16,-1))
    plt.colorbar(im)
 
    return attn_img, fig
 
+model_type = gr.Dropdown(
+    ["vc1-base", "vc1-large"], label="Model Size", value="vc1-large")
 input_img = gr.Image(shape=(250,250))
 input_button = gr.Radio(["min", "max", "mean"], value="min",label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
 output_img = gr.Image(shape=(250,250))
@@ -75,8 +99,8 @@ output_plot = gr.Plot()
 
 markdown ="This is a demo for the Visual Cortex (Base) model. When passed an image input, it displays the attention of the last layer of the transformer.\n \
 The user can decide how the attention heads will be combined. \
-Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 grid."
+Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 or 16x64 grid."
 demo = gr.Interface(fn=run_attn, title="Visual Cortex Base Model", description=markdown,
-examples=[[os.path.join('./imgs',x),None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
-inputs=[input_img,input_button],outputs=[output_img,output_plot])
+examples=[[None, os.path.join('./imgs',x),None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
+inputs=[model_type,input_img,input_button],outputs=[output_img,output_plot])
  demo.launch()
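
The new get_model() keeps one cached (model, embedding_dim, transform, metadata) tuple per model size in a pair of module-level globals, so each checkpoint is instantiated at most once per process; run_attn() then receives the dropdown value as its first argument, and each example gains a leading None so the dropdown keeps its default. Below is a minimal sketch of the same load-once-and-reuse idea using a dict cache instead of two globals. It is not part of the commit: get_model_cached and load_model are hypothetical names, and load_model stands in for the OmegaConf.load + utils.instantiate sequence that app.py performs.

# Sketch only: dict-based version of the per-size caching pattern in get_model().
_MODEL_CACHE = {}

def get_model_cached(model_name, load_model):
    # Return the (model, embedding_dim, transform, metadata) tuple for model_name,
    # instantiating it on first use and reusing the cached tuple afterwards.
    if model_name not in _MODEL_CACHE:
        model_tuple = load_model(model_name)  # stand-in for utils.instantiate(model_cfg)
        model_tuple[0].eval()                 # switch the ViT to inference mode once, as app.py does
        _MODEL_CACHE[model_name] = model_tuple
    return _MODEL_CACHE[model_name]

With a cache keyed by name, supporting a further size would only require the matching checkpoint filename in download_bin(), rather than another global and elif branch in get_model().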