Spaces: watchtowerss / Track-Anything (runtime error)

watchtowerss committed • Commit 5da584d • 1 Parent(s): 94dd0a9
Upload 65 files

Files changed:
- app.py +11 -39
- requirements.txt +5 -1
app.py
CHANGED
@@ -13,13 +13,7 @@ import requests
 import json
 import torchvision
 import torch
-from tools.interact_tools import SamControler
-from tracker.base_tracker import BaseTracker
 from tools.painter import mask_painter
-try:
-    from mmcv.cnn import ConvModule
-except:
-    os.system("mim install mmcv")
 
 # download checkpoints
 def download_checkpoint(url, folder, filename):
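This hunk drops the import-time `mim install mmcv` fallback; the dependency is now declared in requirements.txt instead (see below). The hunk's last context line names the `download_checkpoint` helper, whose body is outside the diff. A minimal sketch of what such a helper might do, assuming a streaming `requests` download; only the name and signature come from the diff, the body here is illustrative:

```python
import os
import requests

def download_checkpoint(url, folder, filename):
    # Skip the download when the checkpoint is already cached locally.
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)
    if not os.path.exists(filepath):
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(filepath, "wb") as f:
            # Stream in chunks so large .pth files are never held in memory whole.
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    return filepath
```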
@@ -206,7 +200,6 @@ def show_mask(video_state, interactive_state, mask_dropdown):
 
 # tracking vos
 def vos_tracking_video(video_state, interactive_state, mask_dropdown):
-
     model.xmem.clear_memory()
     if interactive_state["track_end_number"]:
         following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -226,8 +219,6 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
     template_mask = video_state["masks"][video_state["select_frame_number"]]
     fps = video_state["fps"]
     masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
-    # clear GPU memory
-    model.xmem.clear_memory()
 
     if interactive_state["track_end_number"]:
         video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
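Taken together, the two hunks above relocate the XMem memory reset: it now runs once at the top of `vos_tracking_video` rather than after `model.generator(...)`. Clearing at entry means a run that crashed mid-tracking cannot leave stale memory behind for the next call. The pattern, as a minimal sketch with hypothetical names:

```python
def track_video(tracker, frames, template_mask):
    # Reset at entry: if an earlier call raised partway through tracking,
    # an end-of-function cleanup would never have executed, but an
    # entry-point reset always starts from a clean state.
    tracker.clear_memory()
    return tracker.track(frames, template_mask)
```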
@@ -267,7 +258,6 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
 
 # inpaint
 def inpaint_video(video_state, interactive_state, mask_dropdown):
-
     frames = np.asarray(video_state["origin_images"])
     fps = video_state["fps"]
     inpaint_masks = np.asarray(video_state["masks"])
@@ -314,44 +304,27 @@ def generate_video_from_frames(frames, output_path, fps=30):
     torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
     return output_path
 
-
-# args, defined in track_anything.py
-args = parse_augment()
-
 # check and download checkpoints if needed
-SAM_checkpoint_dict = {
-    'vit_h': "sam_vit_h_4b8939.pth",
-    'vit_l': "sam_vit_l_0b3195.pth",
-    "vit_b": "sam_vit_b_01ec64.pth"
-}
-SAM_checkpoint_url_dict = {
-    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
-    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
-    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
-}
-sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
-sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
+SAM_checkpoint = "sam_vit_h_4b8939.pth"
+sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
 xmem_checkpoint = "XMem-s012.pth"
 xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
 e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
 e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
 
-
 folder ="./checkpoints"
-SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
+SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
 xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
 e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
-
+# args, defined in track_anything.py
+args = parse_augment()
+# args.port = 12315
+# args.device = "cuda:2"
+# args.mask_save = True
 
 # initialize sam, xmem, e2fgvi models
 model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
 
-
-title = """<p><h1 align="center">Track-Anything</h1></p>
-"""
-description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. I To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">https://github.com/gaomingqi/Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
-
-
 with gr.Blocks() as iface:
     """
     state for
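Two things happen in this hunk: the SAM checkpoint choice is hardcoded to `vit_h` (the `args.sam_model_type` lookup dicts are dropped), and the `parse_augment()` call moves below the download block with its optional overrides left as comments. The hunk's leading context also shows `generate_video_from_frames` writing frames via `torchvision.io.write_video`. A self-contained sketch of that helper, assuming `frames` is a sequence of HxWx3 uint8 RGB arrays; the conversion line is an assumption, since only the `write_video` call and the return appear in the diff:

```python
import numpy as np
import torch
import torchvision

def generate_video_from_frames(frames, output_path, fps=30):
    # torchvision.io.write_video expects a (T, H, W, C) uint8 tensor.
    frames = torch.from_numpy(np.asarray(frames, dtype=np.uint8))
    torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
    return output_path
```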
@@ -383,8 +356,7 @@ with gr.Blocks() as iface:
             "fps": 30
         }
     )
-
-    gr.Markdown(description)
+
     with gr.Row():
 
         # for user video input
@@ -393,7 +365,7 @@
             video_input = gr.Video(autosize=True)
         with gr.Column():
             video_info = gr.Textbox()
-
+            video_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
                 Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
             resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
 
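The new textbox advises scaling the input down to roughly 360P via the resize ratio slider before processing. Applied per frame, that might look like the following sketch (a hypothetical helper, assuming OpenCV-style numpy frames; nothing in the diff shows the Space's actual resize code):

```python
import cv2

def resize_frames(frames, ratio=1.0):
    # ratio == 1.0 keeps the original resolution; smaller values shrink
    # every frame proportionally, trading quality for speed and VRAM.
    if ratio == 1.0:
        return frames
    return [
        cv2.resize(frame, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_AREA)
        for frame in frames
    ]
```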
@@ -562,7 +534,7 @@
     # cache_examples=True,
 )
 iface.queue(concurrency_count=1)
-iface.launch(debug=True, enable_queue=True)
+iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
 
 
 
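The launch line now binds explicitly: `server_name="0.0.0.0"` makes the app reachable from outside its container (requests on Spaces arrive through a proxy), and `server_port=args.port` makes the port configurable. Combined with `concurrency_count=1`, queued requests are processed one at a time, which keeps a single GPU from being oversubscribed. A standalone sketch of the same setup against the Gradio 3.x API used here; the port value is illustrative:

```python
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("minimal placeholder UI")

demo.queue(concurrency_count=1)    # serialize GPU-bound requests
demo.launch(
    debug=True,
    enable_queue=True,
    server_name="0.0.0.0",         # bind all interfaces so the proxy can reach the app
    server_port=7860,              # Spaces routes traffic to 7860 by default
)
```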
requirements.txt
CHANGED
@@ -10,6 +10,10 @@ gradio==3.25.0
 opencv-python
 pycocotools
 matplotlib
+onnxruntime
+onnx
+metaseg==0.6.1
 pyyaml
 av
-
+mmcv-full
+mmengine
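These additions pair with the app.py change above: `mmcv-full` and `mmengine` are now declared up front instead of being installed at import time via `mim`, and `onnxruntime`, `onnx`, and `metaseg==0.6.1` are new dependencies. A quick sketch for confirming the newly pinned stack resolves after installation:

```python
# Import check: fails fast if the newly declared dependencies are missing.
import mmcv
import mmengine
import onnxruntime

print(mmcv.__version__, mmengine.__version__, onnxruntime.__version__)
```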