watchtowerss commited on
Commit
71ce351
1 Parent(s): 31fdbec

add max memory usage limit and add highlighted text.

Browse files
Files changed (1) hide show
  1. app.py +37 -35
app.py CHANGED
@@ -16,6 +16,7 @@ import torch
16
  from tools.interact_tools import SamControler
17
  from tracker.base_tracker import BaseTracker
18
  from tools.painter import mask_painter
 
19
  try:
20
  from mmcv.cnn import ConvModule
21
  except:
@@ -69,6 +70,7 @@ def get_prompt(click_state, click_input):
69
  }
70
  return prompt
71
 
 
72
  # extract frames from upload video
73
  def get_frames_from_video(video_input, video_state):
74
  """
@@ -80,13 +82,20 @@ def get_frames_from_video(video_input, video_state):
80
  """
81
  video_path = video_input
82
  frames = []
 
 
83
  try:
84
  cap = cv2.VideoCapture(video_path)
85
  fps = cap.get(cv2.CAP_PROP_FPS)
86
  while cap.isOpened():
87
  ret, frame = cap.read()
88
  if ret == True:
 
89
  frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
 
 
 
 
90
  else:
91
  break
92
  except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
@@ -103,11 +112,10 @@ def get_frames_from_video(video_input, video_state):
103
  "fps": fps
104
  }
105
  video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
106
- operation_log = "Upload video already. Try click the image for adding targets to track and inpaint."
107
  model.samcontroler.sam_controler.reset_image()
108
  model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
109
  return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
110
- gr.update(visible=True), gr.update(visible=True), \
111
  gr.update(visible=True), gr.update(visible=True), \
112
  gr.update(visible=True), gr.update(visible=True), \
113
  gr.update(visible=True), gr.update(visible=True), \
@@ -131,14 +139,14 @@ def select_template(image_selection_slider, video_state, interactive_state):
131
  # update the masks when select a new template frame
132
  # if video_state["masks"][image_selection_slider] is not None:
133
  # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
134
- operation_log = "Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider)
135
 
136
  return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
137
 
138
  # set the tracking end frame
139
  def get_end_number(track_pause_number_slider, video_state, interactive_state):
140
  interactive_state["track_end_number"] = track_pause_number_slider
141
- operation_log = "Set the tracking finish at frame {}".format(track_pause_number_slider)
142
 
143
  return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
144
 
@@ -177,30 +185,33 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
177
  video_state["logits"][video_state["select_frame_number"]] = logit
178
  video_state["painted_images"][video_state["select_frame_number"]] = painted_image
179
 
180
- operation_log = "Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment"
181
  return painted_image, video_state, interactive_state, operation_log
182
 
183
  def add_multi_mask(video_state, interactive_state, mask_dropdown):
184
- mask = video_state["masks"][video_state["select_frame_number"]]
185
- interactive_state["multi_mask"]["masks"].append(mask)
186
- interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
187
- mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
188
- select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
 
189
 
190
- operation_log = "Added a mask, use the mask select for target tracking or inpainting."
 
 
191
  return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
192
 
193
  def clear_click(video_state, click_state):
194
  click_state = [[],[]]
195
  template_frame = video_state["origin_images"][video_state["select_frame_number"]]
196
- operation_log = "Clear points history and refresh the image."
197
  return template_frame, click_state, operation_log
198
 
199
  def remove_multi_mask(interactive_state, mask_dropdown):
200
  interactive_state["multi_mask"]["mask_names"]= []
201
  interactive_state["multi_mask"]["masks"] = []
202
 
203
- operation_log = "Remove all mask, please add new masks"
204
  return interactive_state, gr.update(choices=[],value=[]), operation_log
205
 
206
  def show_mask(video_state, interactive_state, mask_dropdown):
@@ -212,12 +223,12 @@ def show_mask(video_state, interactive_state, mask_dropdown):
212
  mask = interactive_state["multi_mask"]["masks"][mask_number]
213
  select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
214
 
215
- operation_log = "Select {} for tracking or inpainting".format(mask_dropdown)
216
  return select_frame, operation_log
217
 
218
  # tracking vos
219
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
220
- operation_log = "Track the selected masks, and then you can select the masks for inpainting."
221
  model.xmem.clear_memory()
222
  if interactive_state["track_end_number"]:
223
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -240,7 +251,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
240
  # operation error
241
  if len(np.unique(template_mask))==1:
242
  template_mask[0][0]=1
243
- operation_log = "Error! Please add at least one mask to track by clicking the left image."
244
  # return video_output, video_state, interactive_state, operation_error
245
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
246
  # clear GPU memory
@@ -284,7 +295,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
284
 
285
  # inpaint
286
  def inpaint_video(video_state, interactive_state, mask_dropdown):
287
- operation_log = "Removed the selected masks."
288
 
289
  frames = np.asarray(video_state["origin_images"])
290
  fps = video_state["fps"]
@@ -306,7 +317,7 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
306
  try:
307
  inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
308
  except:
309
- operation_log = "Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size."
310
  inpainted_frames = video_state["origin_images"]
311
  video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
312
 
@@ -364,7 +375,7 @@ folder ="./checkpoints"
364
  SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
365
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
366
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
367
- # args.port = 12212
368
  # args.device = "cuda:2"
369
  # args.mask_save = True
370
 
@@ -439,13 +450,7 @@ with gr.Blocks() as iface:
439
  label="Point Prompt",
440
  interactive=True,
441
  visible=False)
442
- click_mode = gr.Radio(
443
- choices=["Continuous", "Single"],
444
- value="Continuous",
445
- label="Clicking Mode",
446
- interactive=True,
447
- visible=False)
448
- with gr.Row():
449
  clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
450
  Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
451
  template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
@@ -453,13 +458,12 @@ with gr.Blocks() as iface:
453
  track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
454
 
455
  with gr.Column():
 
456
  mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
457
- remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
458
  video_output = gr.Video(autosize=True, visible=False).style(height=360)
459
  with gr.Row():
460
  tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
461
  inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
462
- run_status = gr.Textbox(label="Operation log", visible=False)
463
 
464
  # first step: get the video information
465
  extract_frames_button.click(
@@ -468,7 +472,7 @@ with gr.Blocks() as iface:
468
  video_input, video_state
469
  ],
470
  outputs=[video_state, video_info, template_frame,
471
- image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
472
  tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
473
  )
474
 
@@ -551,7 +555,7 @@ with gr.Blocks() as iface:
551
  [[],[]],
552
  None,
553
  None,
554
- gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
555
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
556
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
557
  gr.update(visible=False), gr.update(visible=False)
@@ -564,7 +568,7 @@ with gr.Blocks() as iface:
564
  click_state,
565
  video_output,
566
  template_frame,
567
- tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
568
  Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
569
  ],
570
  queue=False,
@@ -589,7 +593,5 @@ with gr.Blocks() as iface:
589
  # cache_examples=True,
590
  )
591
  iface.queue(concurrency_count=1)
592
- iface.launch(debug=True)
593
-
594
-
595
-
 
16
  from tools.interact_tools import SamControler
17
  from tracker.base_tracker import BaseTracker
18
  from tools.painter import mask_painter
19
+ import psutil
20
  try:
21
  from mmcv.cnn import ConvModule
22
  except:
 
70
  }
71
  return prompt
72
 
73
+
74
  # extract frames from upload video
75
  def get_frames_from_video(video_input, video_state):
76
  """
 
82
  """
83
  video_path = video_input
84
  frames = []
85
+
86
+ operation_log = [("",""),("Upload video already. Try click the image for adding targets to track and inpaint.","Normal")]
87
  try:
88
  cap = cv2.VideoCapture(video_path)
89
  fps = cap.get(cv2.CAP_PROP_FPS)
90
  while cap.isOpened():
91
  ret, frame = cap.read()
92
  if ret == True:
93
+ current_memory_usage = psutil.virtual_memory().percent
94
  frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
95
+ if current_memory_usage > 90:
96
+ operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")]
97
+ print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
98
+ break
99
  else:
100
  break
101
  except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
 
112
  "fps": fps
113
  }
114
  video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
 
115
  model.samcontroler.sam_controler.reset_image()
116
  model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
117
  return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
118
+ gr.update(visible=True),\
119
  gr.update(visible=True), gr.update(visible=True), \
120
  gr.update(visible=True), gr.update(visible=True), \
121
  gr.update(visible=True), gr.update(visible=True), \
 
139
  # update the masks when select a new template frame
140
  # if video_state["masks"][image_selection_slider] is not None:
141
  # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
142
+ operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")]
143
 
144
  return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
145
 
146
  # set the tracking end frame
147
  def get_end_number(track_pause_number_slider, video_state, interactive_state):
148
  interactive_state["track_end_number"] = track_pause_number_slider
149
+ operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")]
150
 
151
  return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
152
 
 
185
  video_state["logits"][video_state["select_frame_number"]] = logit
186
  video_state["painted_images"][video_state["select_frame_number"]] = painted_image
187
 
188
+ operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")]
189
  return painted_image, video_state, interactive_state, operation_log
190
 
191
  def add_multi_mask(video_state, interactive_state, mask_dropdown):
192
+ try:
193
+ mask = video_state["masks"][video_state["select_frame_number"]]
194
+ interactive_state["multi_mask"]["masks"].append(mask)
195
+ interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
196
+ mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
197
+ select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
198
 
199
+ operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
200
+ except:
201
+ operation_log = [("Please click the left image to generate mask.", "Error"), ("","")]
202
  return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
203
 
204
  def clear_click(video_state, click_state):
205
  click_state = [[],[]]
206
  template_frame = video_state["origin_images"][video_state["select_frame_number"]]
207
+ operation_log = [("",""), ("Clear points history and refresh the image.","Normal")]
208
  return template_frame, click_state, operation_log
209
 
210
  def remove_multi_mask(interactive_state, mask_dropdown):
211
  interactive_state["multi_mask"]["mask_names"]= []
212
  interactive_state["multi_mask"]["masks"] = []
213
 
214
+ operation_log = [("",""), ("Remove all mask, please add new masks","Normal")]
215
  return interactive_state, gr.update(choices=[],value=[]), operation_log
216
 
217
  def show_mask(video_state, interactive_state, mask_dropdown):
 
223
  mask = interactive_state["multi_mask"]["masks"][mask_number]
224
  select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
225
 
226
+ operation_log = [("",""), ("Select {} for tracking or inpainting".format(mask_dropdown),"Normal")]
227
  return select_frame, operation_log
228
 
229
  # tracking vos
230
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
231
+ operation_log = [("",""), ("Track the selected masks, and then you can select the masks for inpainting.","Normal")]
232
  model.xmem.clear_memory()
233
  if interactive_state["track_end_number"]:
234
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
 
251
  # operation error
252
  if len(np.unique(template_mask))==1:
253
  template_mask[0][0]=1
254
+ operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")]
255
  # return video_output, video_state, interactive_state, operation_error
256
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
257
  # clear GPU memory
 
295
 
296
  # inpaint
297
  def inpaint_video(video_state, interactive_state, mask_dropdown):
298
+ operation_log = [("",""), ("Removed the selected masks.","Normal")]
299
 
300
  frames = np.asarray(video_state["origin_images"])
301
  fps = video_state["fps"]
 
317
  try:
318
  inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
319
  except:
320
+ operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
321
  inpainted_frames = video_state["origin_images"]
322
  video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
323
 
 
375
  SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
376
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
377
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
378
+ # args.port = 12214
379
  # args.device = "cuda:2"
380
  # args.mask_save = True
381
 
 
450
  label="Point Prompt",
451
  interactive=True,
452
  visible=False)
453
+ remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
 
 
 
 
 
 
454
  clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
455
  Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
456
  template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
 
458
  track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
459
 
460
  with gr.Column():
461
+ run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=False)
462
  mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
 
463
  video_output = gr.Video(autosize=True, visible=False).style(height=360)
464
  with gr.Row():
465
  tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
466
  inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
 
467
 
468
  # first step: get the video information
469
  extract_frames_button.click(
 
472
  video_input, video_state
473
  ],
474
  outputs=[video_state, video_info, template_frame,
475
+ image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame,
476
  tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
477
  )
478
 
 
555
  [[],[]],
556
  None,
557
  None,
558
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
559
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
560
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
561
  gr.update(visible=False), gr.update(visible=False)
 
568
  click_state,
569
  video_output,
570
  template_frame,
571
+ tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
572
  Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
573
  ],
574
  queue=False,
 
593
  # cache_examples=True,
594
  )
595
  iface.queue(concurrency_count=1)
596
+ # iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
597
+ iface.launch(debug=True, enable_queue=True)