Spaces:
Runtime error
Runtime error
watchtowerss
commited on
Commit
•
71ce351
1
Parent(s):
31fdbec
add max memory usage limit and add highlighted text.
Browse files
app.py
CHANGED
@@ -16,6 +16,7 @@ import torch
|
|
16 |
from tools.interact_tools import SamControler
|
17 |
from tracker.base_tracker import BaseTracker
|
18 |
from tools.painter import mask_painter
|
|
|
19 |
try:
|
20 |
from mmcv.cnn import ConvModule
|
21 |
except:
|
@@ -69,6 +70,7 @@ def get_prompt(click_state, click_input):
|
|
69 |
}
|
70 |
return prompt
|
71 |
|
|
|
72 |
# extract frames from upload video
|
73 |
def get_frames_from_video(video_input, video_state):
|
74 |
"""
|
@@ -80,13 +82,20 @@ def get_frames_from_video(video_input, video_state):
|
|
80 |
"""
|
81 |
video_path = video_input
|
82 |
frames = []
|
|
|
|
|
83 |
try:
|
84 |
cap = cv2.VideoCapture(video_path)
|
85 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
86 |
while cap.isOpened():
|
87 |
ret, frame = cap.read()
|
88 |
if ret == True:
|
|
|
89 |
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
|
|
|
|
|
|
|
|
90 |
else:
|
91 |
break
|
92 |
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
|
@@ -103,11 +112,10 @@ def get_frames_from_video(video_input, video_state):
|
|
103 |
"fps": fps
|
104 |
}
|
105 |
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
|
106 |
-
operation_log = "Upload video already. Try click the image for adding targets to track and inpaint."
|
107 |
model.samcontroler.sam_controler.reset_image()
|
108 |
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
|
109 |
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
|
110 |
-
gr.update(visible=True)
|
111 |
gr.update(visible=True), gr.update(visible=True), \
|
112 |
gr.update(visible=True), gr.update(visible=True), \
|
113 |
gr.update(visible=True), gr.update(visible=True), \
|
@@ -131,14 +139,14 @@ def select_template(image_selection_slider, video_state, interactive_state):
|
|
131 |
# update the masks when select a new template frame
|
132 |
# if video_state["masks"][image_selection_slider] is not None:
|
133 |
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
|
134 |
-
operation_log = "Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider)
|
135 |
|
136 |
return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
|
137 |
|
138 |
# set the tracking end frame
|
139 |
def get_end_number(track_pause_number_slider, video_state, interactive_state):
|
140 |
interactive_state["track_end_number"] = track_pause_number_slider
|
141 |
-
operation_log = "Set the tracking finish at frame {}".format(track_pause_number_slider)
|
142 |
|
143 |
return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
|
144 |
|
@@ -177,30 +185,33 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
|
|
177 |
video_state["logits"][video_state["select_frame_number"]] = logit
|
178 |
video_state["painted_images"][video_state["select_frame_number"]] = painted_image
|
179 |
|
180 |
-
operation_log = "Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment"
|
181 |
return painted_image, video_state, interactive_state, operation_log
|
182 |
|
183 |
def add_multi_mask(video_state, interactive_state, mask_dropdown):
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
189 |
|
190 |
-
|
|
|
|
|
191 |
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
|
192 |
|
193 |
def clear_click(video_state, click_state):
|
194 |
click_state = [[],[]]
|
195 |
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
196 |
-
operation_log = "Clear points history and refresh the image."
|
197 |
return template_frame, click_state, operation_log
|
198 |
|
199 |
def remove_multi_mask(interactive_state, mask_dropdown):
|
200 |
interactive_state["multi_mask"]["mask_names"]= []
|
201 |
interactive_state["multi_mask"]["masks"] = []
|
202 |
|
203 |
-
operation_log = "Remove all mask, please add new masks"
|
204 |
return interactive_state, gr.update(choices=[],value=[]), operation_log
|
205 |
|
206 |
def show_mask(video_state, interactive_state, mask_dropdown):
|
@@ -212,12 +223,12 @@ def show_mask(video_state, interactive_state, mask_dropdown):
|
|
212 |
mask = interactive_state["multi_mask"]["masks"][mask_number]
|
213 |
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
|
214 |
|
215 |
-
operation_log = "Select {} for tracking or inpainting".format(mask_dropdown)
|
216 |
return select_frame, operation_log
|
217 |
|
218 |
# tracking vos
|
219 |
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
220 |
-
operation_log = "Track the selected masks, and then you can select the masks for inpainting."
|
221 |
model.xmem.clear_memory()
|
222 |
if interactive_state["track_end_number"]:
|
223 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
@@ -240,7 +251,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
240 |
# operation error
|
241 |
if len(np.unique(template_mask))==1:
|
242 |
template_mask[0][0]=1
|
243 |
-
operation_log = "Error! Please add at least one mask to track by clicking the left image."
|
244 |
# return video_output, video_state, interactive_state, operation_error
|
245 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
246 |
# clear GPU memory
|
@@ -284,7 +295,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
284 |
|
285 |
# inpaint
|
286 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
287 |
-
operation_log = "Removed the selected masks."
|
288 |
|
289 |
frames = np.asarray(video_state["origin_images"])
|
290 |
fps = video_state["fps"]
|
@@ -306,7 +317,7 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
|
|
306 |
try:
|
307 |
inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
|
308 |
except:
|
309 |
-
operation_log = "Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size."
|
310 |
inpainted_frames = video_state["origin_images"]
|
311 |
video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
312 |
|
@@ -364,7 +375,7 @@ folder ="./checkpoints"
|
|
364 |
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
|
365 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
366 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
367 |
-
# args.port =
|
368 |
# args.device = "cuda:2"
|
369 |
# args.mask_save = True
|
370 |
|
@@ -439,13 +450,7 @@ with gr.Blocks() as iface:
|
|
439 |
label="Point Prompt",
|
440 |
interactive=True,
|
441 |
visible=False)
|
442 |
-
|
443 |
-
choices=["Continuous", "Single"],
|
444 |
-
value="Continuous",
|
445 |
-
label="Clicking Mode",
|
446 |
-
interactive=True,
|
447 |
-
visible=False)
|
448 |
-
with gr.Row():
|
449 |
clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
|
450 |
Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
|
451 |
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
|
@@ -453,13 +458,12 @@ with gr.Blocks() as iface:
|
|
453 |
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
|
454 |
|
455 |
with gr.Column():
|
|
|
456 |
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
|
457 |
-
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
|
458 |
video_output = gr.Video(autosize=True, visible=False).style(height=360)
|
459 |
with gr.Row():
|
460 |
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
|
461 |
inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
|
462 |
-
run_status = gr.Textbox(label="Operation log", visible=False)
|
463 |
|
464 |
# first step: get the video information
|
465 |
extract_frames_button.click(
|
@@ -468,7 +472,7 @@ with gr.Blocks() as iface:
|
|
468 |
video_input, video_state
|
469 |
],
|
470 |
outputs=[video_state, video_info, template_frame,
|
471 |
-
image_selection_slider, track_pause_number_slider,point_prompt,
|
472 |
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
|
473 |
)
|
474 |
|
@@ -551,7 +555,7 @@ with gr.Blocks() as iface:
|
|
551 |
[[],[]],
|
552 |
None,
|
553 |
None,
|
554 |
-
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
|
555 |
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
556 |
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
|
557 |
gr.update(visible=False), gr.update(visible=False)
|
@@ -564,7 +568,7 @@ with gr.Blocks() as iface:
|
|
564 |
click_state,
|
565 |
video_output,
|
566 |
template_frame,
|
567 |
-
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt,
|
568 |
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
|
569 |
],
|
570 |
queue=False,
|
@@ -589,7 +593,5 @@ with gr.Blocks() as iface:
|
|
589 |
# cache_examples=True,
|
590 |
)
|
591 |
iface.queue(concurrency_count=1)
|
592 |
-
iface.launch(debug=True)
|
593 |
-
|
594 |
-
|
595 |
-
|
|
|
16 |
from tools.interact_tools import SamControler
|
17 |
from tracker.base_tracker import BaseTracker
|
18 |
from tools.painter import mask_painter
|
19 |
+
import psutil
|
20 |
try:
|
21 |
from mmcv.cnn import ConvModule
|
22 |
except:
|
|
|
70 |
}
|
71 |
return prompt
|
72 |
|
73 |
+
|
74 |
# extract frames from upload video
|
75 |
def get_frames_from_video(video_input, video_state):
|
76 |
"""
|
|
|
82 |
"""
|
83 |
video_path = video_input
|
84 |
frames = []
|
85 |
+
|
86 |
+
operation_log = [("",""),("Upload video already. Try click the image for adding targets to track and inpaint.","Normal")]
|
87 |
try:
|
88 |
cap = cv2.VideoCapture(video_path)
|
89 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
90 |
while cap.isOpened():
|
91 |
ret, frame = cap.read()
|
92 |
if ret == True:
|
93 |
+
current_memory_usage = psutil.virtual_memory().percent
|
94 |
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
95 |
+
if current_memory_usage > 90:
|
96 |
+
operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")]
|
97 |
+
print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
|
98 |
+
break
|
99 |
else:
|
100 |
break
|
101 |
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
|
|
|
112 |
"fps": fps
|
113 |
}
|
114 |
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
|
|
|
115 |
model.samcontroler.sam_controler.reset_image()
|
116 |
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
|
117 |
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
|
118 |
+
gr.update(visible=True),\
|
119 |
gr.update(visible=True), gr.update(visible=True), \
|
120 |
gr.update(visible=True), gr.update(visible=True), \
|
121 |
gr.update(visible=True), gr.update(visible=True), \
|
|
|
139 |
# update the masks when select a new template frame
|
140 |
# if video_state["masks"][image_selection_slider] is not None:
|
141 |
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
|
142 |
+
operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")]
|
143 |
|
144 |
return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
|
145 |
|
146 |
# set the tracking end frame
|
147 |
def get_end_number(track_pause_number_slider, video_state, interactive_state):
|
148 |
interactive_state["track_end_number"] = track_pause_number_slider
|
149 |
+
operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")]
|
150 |
|
151 |
return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
|
152 |
|
|
|
185 |
video_state["logits"][video_state["select_frame_number"]] = logit
|
186 |
video_state["painted_images"][video_state["select_frame_number"]] = painted_image
|
187 |
|
188 |
+
operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")]
|
189 |
return painted_image, video_state, interactive_state, operation_log
|
190 |
|
191 |
def add_multi_mask(video_state, interactive_state, mask_dropdown):
|
192 |
+
try:
|
193 |
+
mask = video_state["masks"][video_state["select_frame_number"]]
|
194 |
+
interactive_state["multi_mask"]["masks"].append(mask)
|
195 |
+
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
196 |
+
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
197 |
+
select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
|
198 |
|
199 |
+
operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
|
200 |
+
except:
|
201 |
+
operation_log = [("Please click the left image to generate mask.", "Error"), ("","")]
|
202 |
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
|
203 |
|
204 |
def clear_click(video_state, click_state):
|
205 |
click_state = [[],[]]
|
206 |
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
207 |
+
operation_log = [("",""), ("Clear points history and refresh the image.","Normal")]
|
208 |
return template_frame, click_state, operation_log
|
209 |
|
210 |
def remove_multi_mask(interactive_state, mask_dropdown):
|
211 |
interactive_state["multi_mask"]["mask_names"]= []
|
212 |
interactive_state["multi_mask"]["masks"] = []
|
213 |
|
214 |
+
operation_log = [("",""), ("Remove all mask, please add new masks","Normal")]
|
215 |
return interactive_state, gr.update(choices=[],value=[]), operation_log
|
216 |
|
217 |
def show_mask(video_state, interactive_state, mask_dropdown):
|
|
|
223 |
mask = interactive_state["multi_mask"]["masks"][mask_number]
|
224 |
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
|
225 |
|
226 |
+
operation_log = [("",""), ("Select {} for tracking or inpainting".format(mask_dropdown),"Normal")]
|
227 |
return select_frame, operation_log
|
228 |
|
229 |
# tracking vos
|
230 |
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
231 |
+
operation_log = [("",""), ("Track the selected masks, and then you can select the masks for inpainting.","Normal")]
|
232 |
model.xmem.clear_memory()
|
233 |
if interactive_state["track_end_number"]:
|
234 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
|
|
251 |
# operation error
|
252 |
if len(np.unique(template_mask))==1:
|
253 |
template_mask[0][0]=1
|
254 |
+
operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")]
|
255 |
# return video_output, video_state, interactive_state, operation_error
|
256 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
257 |
# clear GPU memory
|
|
|
295 |
|
296 |
# inpaint
|
297 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
298 |
+
operation_log = [("",""), ("Removed the selected masks.","Normal")]
|
299 |
|
300 |
frames = np.asarray(video_state["origin_images"])
|
301 |
fps = video_state["fps"]
|
|
|
317 |
try:
|
318 |
inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
|
319 |
except:
|
320 |
+
operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
|
321 |
inpainted_frames = video_state["origin_images"]
|
322 |
video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
323 |
|
|
|
375 |
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
|
376 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
377 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
378 |
+
# args.port = 12214
|
379 |
# args.device = "cuda:2"
|
380 |
# args.mask_save = True
|
381 |
|
|
|
450 |
label="Point Prompt",
|
451 |
interactive=True,
|
452 |
visible=False)
|
453 |
+
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
454 |
clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
|
455 |
Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
|
456 |
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
|
|
|
458 |
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
|
459 |
|
460 |
with gr.Column():
|
461 |
+
run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=False)
|
462 |
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
|
|
|
463 |
video_output = gr.Video(autosize=True, visible=False).style(height=360)
|
464 |
with gr.Row():
|
465 |
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
|
466 |
inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
|
|
|
467 |
|
468 |
# first step: get the video information
|
469 |
extract_frames_button.click(
|
|
|
472 |
video_input, video_state
|
473 |
],
|
474 |
outputs=[video_state, video_info, template_frame,
|
475 |
+
image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame,
|
476 |
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
|
477 |
)
|
478 |
|
|
|
555 |
[[],[]],
|
556 |
None,
|
557 |
None,
|
558 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
559 |
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
560 |
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
|
561 |
gr.update(visible=False), gr.update(visible=False)
|
|
|
568 |
click_state,
|
569 |
video_output,
|
570 |
template_frame,
|
571 |
+
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
|
572 |
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
|
573 |
],
|
574 |
queue=False,
|
|
|
593 |
# cache_examples=True,
|
594 |
)
|
595 |
iface.queue(concurrency_count=1)
|
596 |
+
# iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
|
597 |
+
iface.launch(debug=True, enable_queue=True)
|
|
|
|