watchtowerss commited on
Commit
3c7c9f9
1 Parent(s): 6738b38

operation prompt version

Browse files
Files changed (2) hide show
  1. app.py +59 -36
  2. inpainter/base_inpainter.py +10 -10
app.py CHANGED
@@ -103,7 +103,7 @@ def get_frames_from_video(video_input, video_state):
103
  "fps": fps
104
  }
105
  video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
106
-
107
  model.samcontroler.sam_controler.reset_image()
108
  model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
109
  return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
@@ -111,7 +111,8 @@ def get_frames_from_video(video_input, video_state):
111
  gr.update(visible=True), gr.update(visible=True), \
112
  gr.update(visible=True), gr.update(visible=True), \
113
  gr.update(visible=True), gr.update(visible=True), \
114
- gr.update(visible=True), gr.update(visible=True)
 
115
 
116
  def run_example(example):
117
  return video_input
@@ -130,15 +131,16 @@ def select_template(image_selection_slider, video_state, interactive_state):
130
  # update the masks when select a new template frame
131
  # if video_state["masks"][image_selection_slider] is not None:
132
  # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
 
133
 
134
-
135
- return video_state["painted_images"][image_selection_slider], video_state, interactive_state
136
 
137
  # set the tracking end frame
138
  def get_end_number(track_pause_number_slider, video_state, interactive_state):
139
  interactive_state["track_end_number"] = track_pause_number_slider
 
140
 
141
- return video_state["painted_images"][track_pause_number_slider],interactive_state
142
 
143
  def get_resize_ratio(resize_ratio_slider, interactive_state):
144
  interactive_state["resize_ratio"] = resize_ratio_slider
@@ -175,25 +177,31 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
175
  video_state["logits"][video_state["select_frame_number"]] = logit
176
  video_state["painted_images"][video_state["select_frame_number"]] = painted_image
177
 
178
- return painted_image, video_state, interactive_state
 
179
 
180
  def add_multi_mask(video_state, interactive_state, mask_dropdown):
181
  mask = video_state["masks"][video_state["select_frame_number"]]
182
  interactive_state["multi_mask"]["masks"].append(mask)
183
  interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
184
  mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
185
- select_frame = show_mask(video_state, interactive_state, mask_dropdown)
186
- return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
 
 
187
 
188
  def clear_click(video_state, click_state):
189
  click_state = [[],[]]
190
  template_frame = video_state["origin_images"][video_state["select_frame_number"]]
191
- return template_frame, click_state
 
192
 
193
- def remove_multi_mask(interactive_state):
194
  interactive_state["multi_mask"]["mask_names"]= []
195
  interactive_state["multi_mask"]["masks"] = []
196
- return interactive_state
 
 
197
 
198
  def show_mask(video_state, interactive_state, mask_dropdown):
199
  mask_dropdown.sort()
@@ -203,12 +211,13 @@ def show_mask(video_state, interactive_state, mask_dropdown):
203
  mask_number = int(mask_dropdown[i].split("_")[1]) - 1
204
  mask = interactive_state["multi_mask"]["masks"][mask_number]
205
  select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
206
-
207
- return select_frame
 
208
 
209
  # tracking vos
210
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
211
-
212
  model.xmem.clear_memory()
213
  if interactive_state["track_end_number"]:
214
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -227,6 +236,12 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
227
  else:
228
  template_mask = video_state["masks"][video_state["select_frame_number"]]
229
  fps = video_state["fps"]
 
 
 
 
 
 
230
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
231
  # clear GPU memory
232
  model.xmem.clear_memory()
@@ -259,7 +274,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
259
  i+=1
260
  # save_mask(video_state["masks"], video_state["video_name"])
261
  #### shanggao code for mask save
262
- return video_output, video_state, interactive_state
263
 
264
  # extracting masks from mask_dropdown
265
  # def extract_sole_mask(video_state, mask_dropdown):
@@ -269,6 +284,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
269
 
270
  # inpaint
271
  def inpaint_video(video_state, interactive_state, mask_dropdown):
 
272
 
273
  frames = np.asarray(video_state["origin_images"])
274
  fps = video_state["fps"]
@@ -286,10 +302,15 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
286
  continue
287
  inpaint_masks[inpaint_masks==i] = 0
288
  # inpaint for videos
289
- inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
 
 
 
 
 
290
  video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
291
 
292
- return video_output
293
 
294
 
295
  # generate video after vos inference
@@ -343,7 +364,7 @@ folder ="./checkpoints"
343
  SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
344
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
345
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
346
- # args.port = 12315
347
  # args.device = "cuda:2"
348
  # args.mask_save = True
349
 
@@ -396,9 +417,9 @@ with gr.Blocks() as iface:
396
  with gr.Row(scale=0.4):
397
  video_input = gr.Video(autosize=True)
398
  with gr.Column():
399
- video_info = gr.Textbox()
400
- resize_info = gr.Textbox(value="Due to server restrictions, please upload a video that is no longer than 2 minutes. If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
401
- Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
402
  resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
403
 
404
 
@@ -432,12 +453,13 @@ with gr.Blocks() as iface:
432
  track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
433
 
434
  with gr.Column():
435
- mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask_select", info=".", visible=False)
436
  remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
437
  video_output = gr.Video(autosize=True, visible=False).style(height=360)
438
  with gr.Row():
439
  tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
440
  inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
 
441
 
442
  # first step: get the video information
443
  extract_frames_button.click(
@@ -447,16 +469,16 @@ with gr.Blocks() as iface:
447
  ],
448
  outputs=[video_state, video_info, template_frame,
449
  image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
450
- tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button]
451
  )
452
 
453
  # second step: select images from slider
454
  image_selection_slider.release(fn=select_template,
455
  inputs=[image_selection_slider, video_state, interactive_state],
456
- outputs=[template_frame, video_state, interactive_state], api_name="select_image")
457
  track_pause_number_slider.release(fn=get_end_number,
458
  inputs=[track_pause_number_slider, video_state, interactive_state],
459
- outputs=[template_frame, interactive_state], api_name="end_image")
460
  resize_ratio_slider.release(fn=get_resize_ratio,
461
  inputs=[resize_ratio_slider, interactive_state],
462
  outputs=[interactive_state], api_name="resize_ratio")
@@ -465,41 +487,41 @@ with gr.Blocks() as iface:
465
  template_frame.select(
466
  fn=sam_refine,
467
  inputs=[video_state, point_prompt, click_state, interactive_state],
468
- outputs=[template_frame, video_state, interactive_state]
469
  )
470
 
471
  # add different mask
472
  Add_mask_button.click(
473
  fn=add_multi_mask,
474
  inputs=[video_state, interactive_state, mask_dropdown],
475
- outputs=[interactive_state, mask_dropdown, template_frame, click_state]
476
  )
477
 
478
  remove_mask_button.click(
479
  fn=remove_multi_mask,
480
- inputs=[interactive_state],
481
- outputs=[interactive_state]
482
  )
483
 
484
  # tracking video from select image and mask
485
  tracking_video_predict_button.click(
486
  fn=vos_tracking_video,
487
  inputs=[video_state, interactive_state, mask_dropdown],
488
- outputs=[video_output, video_state, interactive_state]
489
  )
490
 
491
  # inpaint video from select image and mask
492
  inpaint_video_predict_button.click(
493
  fn=inpaint_video,
494
  inputs=[video_state, interactive_state, mask_dropdown],
495
- outputs=[video_output]
496
  )
497
 
498
  # click to get mask
499
  mask_dropdown.change(
500
  fn=show_mask,
501
  inputs=[video_state, interactive_state, mask_dropdown],
502
- outputs=[template_frame]
503
  )
504
 
505
  # clear input
@@ -531,7 +553,8 @@ with gr.Blocks() as iface:
531
  None,
532
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
533
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
534
- gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), gr.update(visible=False) \
 
535
 
536
  ),
537
  [],
@@ -542,7 +565,7 @@ with gr.Blocks() as iface:
542
  video_output,
543
  template_frame,
544
  tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
545
- Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button
546
  ],
547
  queue=False,
548
  show_progress=False)
@@ -551,7 +574,7 @@ with gr.Blocks() as iface:
551
  clear_button_click.click(
552
  fn = clear_click,
553
  inputs = [video_state, click_state,],
554
- outputs = [template_frame,click_state],
555
  )
556
  # set example
557
  gr.Markdown("## Examples")
@@ -566,7 +589,7 @@ with gr.Blocks() as iface:
566
  # cache_examples=True,
567
  )
568
  iface.queue(concurrency_count=1)
569
- iface.launch(debug=True)
570
 
571
 
572
 
 
103
  "fps": fps
104
  }
105
  video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
106
+ operation_log = "Upload video already. Try click the image for adding targets to track and inpaint."
107
  model.samcontroler.sam_controler.reset_image()
108
  model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
109
  return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
 
111
  gr.update(visible=True), gr.update(visible=True), \
112
  gr.update(visible=True), gr.update(visible=True), \
113
  gr.update(visible=True), gr.update(visible=True), \
114
+ gr.update(visible=True), gr.update(visible=True), \
115
+ gr.update(visible=True, value=operation_log)
116
 
117
  def run_example(example):
118
  return video_input
 
131
  # update the masks when select a new template frame
132
  # if video_state["masks"][image_selection_slider] is not None:
133
  # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
134
+ operation_log = "Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider)
135
 
136
+ return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
 
137
 
138
  # set the tracking end frame
139
  def get_end_number(track_pause_number_slider, video_state, interactive_state):
140
  interactive_state["track_end_number"] = track_pause_number_slider
141
+ operation_log = "Set the tracking finish at frame {}".format(track_pause_number_slider)
142
 
143
+ return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
144
 
145
  def get_resize_ratio(resize_ratio_slider, interactive_state):
146
  interactive_state["resize_ratio"] = resize_ratio_slider
 
177
  video_state["logits"][video_state["select_frame_number"]] = logit
178
  video_state["painted_images"][video_state["select_frame_number"]] = painted_image
179
 
180
+ operation_log = "Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment"
181
+ return painted_image, video_state, interactive_state, operation_log
182
 
183
  def add_multi_mask(video_state, interactive_state, mask_dropdown):
184
  mask = video_state["masks"][video_state["select_frame_number"]]
185
  interactive_state["multi_mask"]["masks"].append(mask)
186
  interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
187
  mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
188
+ select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
189
+
190
+ operation_log = "Added a mask, use the mask select for target tracking or inpainting."
191
+ return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
192
 
193
  def clear_click(video_state, click_state):
194
  click_state = [[],[]]
195
  template_frame = video_state["origin_images"][video_state["select_frame_number"]]
196
+ operation_log = "Clear points history and refresh the image."
197
+ return template_frame, click_state, operation_log
198
 
199
+ def remove_multi_mask(interactive_state, mask_dropdown):
200
  interactive_state["multi_mask"]["mask_names"]= []
201
  interactive_state["multi_mask"]["masks"] = []
202
+
203
+ operation_log = "Remove all mask, please add new masks"
204
+ return interactive_state, gr.update(choices=[],value=[]), operation_log
205
 
206
  def show_mask(video_state, interactive_state, mask_dropdown):
207
  mask_dropdown.sort()
 
211
  mask_number = int(mask_dropdown[i].split("_")[1]) - 1
212
  mask = interactive_state["multi_mask"]["masks"][mask_number]
213
  select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
214
+
215
+ operation_log = "Select {} for tracking or inpainting".format(mask_dropdown)
216
+ return select_frame, operation_log
217
 
218
  # tracking vos
219
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
220
+ operation_log = "Track the selected masks, and then you can select the masks for inpainting."
221
  model.xmem.clear_memory()
222
  if interactive_state["track_end_number"]:
223
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
 
236
  else:
237
  template_mask = video_state["masks"][video_state["select_frame_number"]]
238
  fps = video_state["fps"]
239
+
240
+ # operation error
241
+ if len(np.unique(template_mask))==1:
242
+ template_mask[0][0]=1
243
+ operation_log = "Error! Please add at least one mask to track by clicking the left image."
244
+ # return video_output, video_state, interactive_state, operation_error
245
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
246
  # clear GPU memory
247
  model.xmem.clear_memory()
 
274
  i+=1
275
  # save_mask(video_state["masks"], video_state["video_name"])
276
  #### shanggao code for mask save
277
+ return video_output, video_state, interactive_state, operation_log
278
 
279
  # extracting masks from mask_dropdown
280
  # def extract_sole_mask(video_state, mask_dropdown):
 
284
 
285
  # inpaint
286
  def inpaint_video(video_state, interactive_state, mask_dropdown):
287
+ operation_log = "Removed the selected masks."
288
 
289
  frames = np.asarray(video_state["origin_images"])
290
  fps = video_state["fps"]
 
302
  continue
303
  inpaint_masks[inpaint_masks==i] = 0
304
  # inpaint for videos
305
+
306
+ try:
307
+ inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
308
+ except:
309
+ operation_log = "Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size."
310
+ inpainted_frames = video_state["origin_images"]
311
  video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
312
 
313
+ return video_output, operation_log
314
 
315
 
316
  # generate video after vos inference
 
364
  SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
365
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
366
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
367
+ # args.port = 12212
368
  # args.device = "cuda:2"
369
  # args.mask_save = True
370
 
 
417
  with gr.Row(scale=0.4):
418
  video_input = gr.Video(autosize=True)
419
  with gr.Column():
420
+ video_info = gr.Textbox(label="Video Info")
421
+ resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
422
+ Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
423
  resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
424
 
425
 
 
453
  track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
454
 
455
  with gr.Column():
456
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
457
  remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
458
  video_output = gr.Video(autosize=True, visible=False).style(height=360)
459
  with gr.Row():
460
  tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
461
  inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
462
+ run_status = gr.Textbox(label="Operation log", visible=False)
463
 
464
  # first step: get the video information
465
  extract_frames_button.click(
 
469
  ],
470
  outputs=[video_state, video_info, template_frame,
471
  image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
472
+ tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
473
  )
474
 
475
  # second step: select images from slider
476
  image_selection_slider.release(fn=select_template,
477
  inputs=[image_selection_slider, video_state, interactive_state],
478
+ outputs=[template_frame, video_state, interactive_state, run_status], api_name="select_image")
479
  track_pause_number_slider.release(fn=get_end_number,
480
  inputs=[track_pause_number_slider, video_state, interactive_state],
481
+ outputs=[template_frame, interactive_state, run_status], api_name="end_image")
482
  resize_ratio_slider.release(fn=get_resize_ratio,
483
  inputs=[resize_ratio_slider, interactive_state],
484
  outputs=[interactive_state], api_name="resize_ratio")
 
487
  template_frame.select(
488
  fn=sam_refine,
489
  inputs=[video_state, point_prompt, click_state, interactive_state],
490
+ outputs=[template_frame, video_state, interactive_state, run_status]
491
  )
492
 
493
  # add different mask
494
  Add_mask_button.click(
495
  fn=add_multi_mask,
496
  inputs=[video_state, interactive_state, mask_dropdown],
497
+ outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status]
498
  )
499
 
500
  remove_mask_button.click(
501
  fn=remove_multi_mask,
502
+ inputs=[interactive_state, mask_dropdown],
503
+ outputs=[interactive_state, mask_dropdown, run_status]
504
  )
505
 
506
  # tracking video from select image and mask
507
  tracking_video_predict_button.click(
508
  fn=vos_tracking_video,
509
  inputs=[video_state, interactive_state, mask_dropdown],
510
+ outputs=[video_output, video_state, interactive_state, run_status]
511
  )
512
 
513
  # inpaint video from select image and mask
514
  inpaint_video_predict_button.click(
515
  fn=inpaint_video,
516
  inputs=[video_state, interactive_state, mask_dropdown],
517
+ outputs=[video_output, run_status]
518
  )
519
 
520
  # click to get mask
521
  mask_dropdown.change(
522
  fn=show_mask,
523
  inputs=[video_state, interactive_state, mask_dropdown],
524
+ outputs=[template_frame, run_status]
525
  )
526
 
527
  # clear input
 
553
  None,
554
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
555
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
556
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
557
+ gr.update(visible=False), gr.update(visible=False)
558
 
559
  ),
560
  [],
 
565
  video_output,
566
  template_frame,
567
  tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
568
+ Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
569
  ],
570
  queue=False,
571
  show_progress=False)
 
574
  clear_button_click.click(
575
  fn = clear_click,
576
  inputs = [video_state, click_state,],
577
+ outputs = [template_frame,click_state, run_status],
578
  )
579
  # set example
580
  gr.Markdown("## Examples")
 
589
  # cache_examples=True,
590
  )
591
  iface.queue(concurrency_count=1)
592
+ iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
593
 
594
 
595
 
inpainter/base_inpainter.py CHANGED
@@ -64,21 +64,21 @@ class BaseInpainter:
64
  masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
65
 
66
  T, H, W = masks.shape
 
67
  # size: (w, h)
68
  if ratio == 1:
69
  size = None
 
70
  else:
71
  size = [int(W*ratio), int(H*ratio)]
72
- if size[0] % 2 > 0:
73
- size[0] += 1
74
- if size[1] % 2 > 0:
75
- size[1] += 1
76
-
77
- masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
78
- binary_masks = resize_masks(masks, tuple(size))
79
- frames = resize_frames(frames, tuple(size)) # T, H, W, 3
80
  # frames and binary_masks are numpy arrays
81
-
82
  h, w = frames.shape[1:3]
83
  video_length = T
84
 
@@ -156,7 +156,7 @@ if __name__ == '__main__':
156
  base_inpainter = BaseInpainter(checkpoint, device)
157
  # 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
158
  # ratio: (0, 1], ratio for down sample, default value is 1
159
- inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=1) # numpy array, T, H, W, 3
160
  # ----------------------------------------------
161
  # end
162
  # ----------------------------------------------
 
64
  masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
65
 
66
  T, H, W = masks.shape
67
+ masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
68
  # size: (w, h)
69
  if ratio == 1:
70
  size = None
71
+ binary_masks = masks
72
  else:
73
  size = [int(W*ratio), int(H*ratio)]
74
+ size = [si+1 if si%2>0 else si for si in size] # only consider even values
75
+ # shortest side should be larger than 50
76
+ if min(size) < 50:
77
+ ratio = 50. / min(H, W)
78
+ size = [int(W*ratio), int(H*ratio)]
79
+ binary_masks = resize_masks(masks, tuple(size))
80
+ frames = resize_frames(frames, tuple(size)) # T, H, W, 3
 
81
  # frames and binary_masks are numpy arrays
 
82
  h, w = frames.shape[1:3]
83
  video_length = T
84
 
 
156
  base_inpainter = BaseInpainter(checkpoint, device)
157
  # 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
158
  # ratio: (0, 1], ratio for down sample, default value is 1
159
+ inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.01) # numpy array, T, H, W, 3
160
  # ----------------------------------------------
161
  # end
162
  # ----------------------------------------------