Beom0 commited on
Commit
3567b7f
·
verified ·
1 Parent(s): dd6d517

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -20,6 +20,8 @@ from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
20
  from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
21
  from zim.utils import show_mat_anns
22
 
 
 
23
  def get_shortest_axis(image):
24
  h, w, _ = image.shape
25
  return h if h < w else w
@@ -233,12 +235,18 @@ def get_examples():
233
  images = os.listdir(assets_dir)
234
  return [os.path.join(assets_dir, img) for img in images]
235
 
 
 
 
 
 
 
 
236
  if __name__ == "__main__":
237
  backbone = "vit_b"
238
 
239
  # load ZIM
240
- ckpt_mat = "ckpts/zim_vit_b_2043"
241
- zim = zim_model_registry[backbone](checkpoint=ckpt_mat)
242
  if torch.cuda.is_available():
243
  zim.cuda()
244
  zim_predictor = ZimPredictor(zim)
 
20
  from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
21
  from zim.utils import show_mat_anns
22
 
23
+ from huggingface_hub import hf_hub_download
24
+
25
  def get_shortest_axis(image):
26
  h, w, _ = image.shape
27
  return h if h < w else w
 
235
  images = os.listdir(assets_dir)
236
  return [os.path.join(assets_dir, img) for img in images]
237
 
238
+ def download_onnx_weights(repo_id="naver-iv/zim-anything-vitb", file_dir="zim_vit_b_2043"):
239
+ hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/encoder.onnx")
240
+ filepath = hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/decoder.onnx")
241
+
242
+ return os.path.dirname(filepath)
243
+
244
+
245
  if __name__ == "__main__":
246
  backbone = "vit_b"
247
 
248
  # load ZIM
249
+ zim = zim_model_registry[backbone](checkpoint=download_onnx_weights())
 
250
  if torch.cuda.is_available():
251
  zim.cuda()
252
  zim_predictor = ZimPredictor(zim)