Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -20,6 +20,8 @@ from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
|
|
20 |
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
21 |
from zim.utils import show_mat_anns
|
22 |
|
|
|
|
|
23 |
def get_shortest_axis(image):
|
24 |
h, w, _ = image.shape
|
25 |
return h if h < w else w
|
@@ -233,12 +235,18 @@ def get_examples():
|
|
233 |
images = os.listdir(assets_dir)
|
234 |
return [os.path.join(assets_dir, img) for img in images]
|
235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
if __name__ == "__main__":
|
237 |
backbone = "vit_b"
|
238 |
|
239 |
# load ZIM
|
240 |
-
|
241 |
-
zim = zim_model_registry[backbone](checkpoint=ckpt_mat)
|
242 |
if torch.cuda.is_available():
|
243 |
zim.cuda()
|
244 |
zim_predictor = ZimPredictor(zim)
|
|
|
20 |
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
21 |
from zim.utils import show_mat_anns
|
22 |
|
23 |
+
from huggingface_hub import hf_hub_download
|
24 |
+
|
25 |
def get_shortest_axis(image):
|
26 |
h, w, _ = image.shape
|
27 |
return h if h < w else w
|
|
|
235 |
images = os.listdir(assets_dir)
|
236 |
return [os.path.join(assets_dir, img) for img in images]
|
237 |
|
238 |
+
def download_onnx_weights(repo_id="naver-iv/zim-anything-vitb", file_dir="zim_vit_b_2043"):
|
239 |
+
hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/encoder.onnx")
|
240 |
+
filepath = hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/decoder.onnx")
|
241 |
+
|
242 |
+
return os.path.dirname(filepath)
|
243 |
+
|
244 |
+
|
245 |
if __name__ == "__main__":
|
246 |
backbone = "vit_b"
|
247 |
|
248 |
# load ZIM
|
249 |
+
zim = zim_model_registry[backbone](checkpoint=download_onnx_weights())
|
|
|
250 |
if torch.cuda.is_available():
|
251 |
zim.cuda()
|
252 |
zim_predictor = ZimPredictor(zim)
|