watchtowerss commited on
Commit
508b599
1 Parent(s): 5fd7b77

RAM and VRAM usage reduce

Browse files
app.py CHANGED
@@ -341,7 +341,6 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
341
  operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
342
  inpainted_frames = video_state["origin_images"]
343
  video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
344
-
345
  return video_output, operation_log
346
 
347
 
@@ -423,7 +422,7 @@ SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
423
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
424
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
425
  # args.port = 12213
426
- # args.device = "cuda:1"
427
  # args.mask_save = True
428
 
429
  # initialize sam, xmem, e2fgvi models
@@ -432,7 +431,7 @@ model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args
432
 
433
  title = """<p><h1 align="center">Track-Anything</h1></p>
434
  """
435
- description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. I To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">https://github.com/gaomingqi/Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
436
 
437
 
438
  with gr.Blocks() as iface:
@@ -450,7 +449,7 @@ with gr.Blocks() as iface:
450
  "masks": []
451
  },
452
  "track_end_number": None,
453
- "resize_ratio": 1
454
  }
455
  )
456
 
@@ -470,48 +469,78 @@ with gr.Blocks() as iface:
470
  gr.Markdown(title)
471
  gr.Markdown(description)
472
  with gr.Row():
473
-
474
- # for user video input
475
  with gr.Column():
476
- with gr.Row(scale=0.4):
477
- video_input = gr.Video(autosize=True)
478
- with gr.Column():
479
- video_info = gr.Textbox(label="Video Info")
480
- resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
481
- Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
482
- resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
483
-
484
-
485
- with gr.Row():
486
- # put the template frame under the radio button
487
  with gr.Column():
488
- # extract frames
489
- with gr.Column():
490
- extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
 
 
 
 
 
491
 
492
- # click points settins, negative or positive, mode continuous or single
493
  with gr.Row():
494
- with gr.Row():
495
- point_prompt = gr.Radio(
496
- choices=["Positive", "Negative"],
497
- value="Positive",
498
- label="Point Prompt",
499
- interactive=True,
500
- visible=False)
501
- remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
502
- clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
503
- Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
504
- template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
505
- image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", visible=False)
506
- track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
507
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  with gr.Column():
509
- run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=True)
510
- mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
511
- video_output = gr.Video(autosize=True, visible=False).style(height=360)
512
- with gr.Row():
513
- tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
514
- inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
 
 
 
 
 
 
 
 
515
 
516
  # first step: get the video information
517
  extract_frames_button.click(
@@ -601,7 +630,7 @@ with gr.Blocks() as iface:
601
  "masks": []
602
  },
603
  "track_end_number": 0,
604
- "resize_ratio": 1
605
  },
606
  [[],[]],
607
  None,
@@ -609,7 +638,7 @@ with gr.Blocks() as iface:
609
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
610
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
611
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
612
- gr.update(visible=False), gr.update(visible=False)
613
 
614
  ),
615
  [],
@@ -631,18 +660,6 @@ with gr.Blocks() as iface:
631
  inputs = [video_state, click_state,],
632
  outputs = [template_frame,click_state, run_status],
633
  )
634
- # set example
635
- gr.Markdown("## Examples")
636
- gr.Examples(
637
- examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample8.mp4","test-sample4.mp4", \
638
- "test-sample2.mp4","test-sample13.mp4"]],
639
- fn=run_example,
640
- inputs=[
641
- video_input
642
- ],
643
- outputs=[video_input],
644
- # cache_examples=True,
645
- )
646
  iface.queue(concurrency_count=1)
647
  # iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
648
  iface.launch(debug=True, enable_queue=True)
341
  operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
342
  inpainted_frames = video_state["origin_images"]
343
  video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
 
344
  return video_output, operation_log
345
 
346
 
422
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
423
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
424
  # args.port = 12213
425
+ # args.device = "cuda:8"
426
  # args.mask_save = True
427
 
428
  # initialize sam, xmem, e2fgvi models
431
 
432
  title = """<p><h1 align="center">Track-Anything</h1></p>
433
  """
434
+ description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a> If you stuck in unknown errors, please feel free to watch the Tutorial video.</p>"""
435
 
436
 
437
  with gr.Blocks() as iface:
449
  "masks": []
450
  },
451
  "track_end_number": None,
452
+ "resize_ratio": 0.6
453
  }
454
  )
455
 
469
  gr.Markdown(title)
470
  gr.Markdown(description)
471
  with gr.Row():
 
 
472
  with gr.Column():
473
+ with gr.Tab("Test"):
474
+ # for user video input
 
 
 
 
 
 
 
 
 
475
  with gr.Column():
476
+ with gr.Row(scale=0.4):
477
+ video_input = gr.Video(autosize=True)
478
+ with gr.Column():
479
+ video_info = gr.Textbox(label="Video Info")
480
+ resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
481
+ Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
482
+ resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=0.6, label="Resize ratio", visible=True)
483
+
484
 
 
485
  with gr.Row():
486
+ # put the template frame under the radio button
487
+ with gr.Column():
488
+ # extract frames
489
+ with gr.Column():
490
+ extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
491
+
492
+ # click points settins, negative or positive, mode continuous or single
493
+ with gr.Row():
494
+ with gr.Row():
495
+ point_prompt = gr.Radio(
496
+ choices=["Positive", "Negative"],
497
+ value="Positive",
498
+ label="Point Prompt",
499
+ interactive=True,
500
+ visible=False)
501
+ remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
502
+ clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
503
+ Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
504
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
505
+ image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", visible=False)
506
+ track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
507
+
508
+ with gr.Column():
509
+ run_status = gr.HighlightedText(value=[("Run","Error"),("Status","Normal")], visible=True)
510
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
511
+ video_output = gr.Video(autosize=True, visible=False).style(height=360)
512
+ with gr.Row():
513
+ tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
514
+ inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
515
+ # set example
516
+ gr.Markdown("## Examples")
517
+ gr.Examples(
518
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample8.mp4","test-sample4.mp4", \
519
+ "test-sample2.mp4","test-sample13.mp4"]],
520
+ fn=run_example,
521
+ inputs=[
522
+ video_input
523
+ ],
524
+ outputs=[video_input],
525
+ # cache_examples=True,
526
+ )
527
+
528
+ with gr.Tab("Tutorial"):
529
  with gr.Column():
530
+ with gr.Row(scale=0.4):
531
+ video_demo_operation = gr.Video(autosize=True)
532
+
533
+ # set example
534
+ gr.Markdown("## Operation tutorial video")
535
+ gr.Examples(
536
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["huggingface_demo_operation.mp4"]],
537
+ fn=run_example,
538
+ inputs=[
539
+ video_demo_operation
540
+ ],
541
+ outputs=[video_demo_operation],
542
+ # cache_examples=True,
543
+ )
544
 
545
  # first step: get the video information
546
  extract_frames_button.click(
630
  "masks": []
631
  },
632
  "track_end_number": 0,
633
+ "resize_ratio": 0.6
634
  },
635
  [[],[]],
636
  None,
638
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
639
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
640
  gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
641
+ gr.update(visible=False), gr.update(visible=True)
642
 
643
  ),
644
  [],
660
  inputs = [video_state, click_state,],
661
  outputs = [template_frame,click_state, run_status],
662
  )
 
 
 
 
 
 
 
 
 
 
 
 
663
  iface.queue(concurrency_count=1)
664
  # iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
665
  iface.launch(debug=True, enable_queue=True)
inpainter/base_inpainter.py CHANGED
@@ -9,9 +9,9 @@ import numpy as np
9
  from tqdm import tqdm
10
  from inpainter.util.tensor_util import resize_frames, resize_masks
11
 
12
- def read_image_from_userfolder(image_path):
13
  # if type:
14
- image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
15
  # else:
16
  # image = cv2.cvtColor(cv2.imread("/tmp/{}/paintedimages/{}/{:08d}.png".format(username, video_state["video_name"], index+ ".png")), cv2.COLOR_BGR2RGB)
17
  return image
@@ -56,29 +56,30 @@ class BaseInpainter:
56
  break
57
  ref_index.append(i)
58
  return ref_index
59
-
60
- def inpaint(self, frames_path, masks, dilate_radius=15, ratio=1):
61
  """
 
62
  frames: numpy array, T, H, W, 3
63
  masks: numpy array, T, H, W
 
 
64
  dilate_radius: radius when applying dilation on masks
65
  ratio: down-sample ratio
66
 
67
  Output:
68
  inpainted_frames: numpy array, T, H, W, 3
69
  """
70
- frames = []
71
- for file in frames_path:
72
- frames.append(read_image_from_userfolder(file))
73
- frames = np.asarray(frames)
74
-
75
  assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
76
  assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
 
 
 
 
77
  masks = masks.copy()
78
  masks = np.clip(masks, 0, 1)
79
  kernel = cv2.getStructuringElement(2, (dilate_radius, dilate_radius))
80
  masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
81
-
82
  T, H, W = masks.shape
83
  masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
84
  # size: (w, h)
@@ -96,14 +97,37 @@ class BaseInpainter:
96
  frames = resize_frames(frames, tuple(size)) # T, H, W, 3
97
  # frames and binary_masks are numpy arrays
98
  h, w = frames.shape[1:3]
99
- video_length = T
100
-
101
  # convert to tensor
102
  imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
103
  masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
104
-
105
  imgs, masks = imgs.to(self.device), masks.to(self.device)
106
  comp_frames = [None] * video_length
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
109
  neighbor_ids = [
@@ -111,8 +135,24 @@ class BaseInpainter:
111
  min(video_length, f + self.neighbor_stride + 1))
112
  ]
113
  ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
114
- selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
115
- selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  with torch.no_grad():
117
  masked_imgs = selected_imgs * (1 - selected_masks)
118
  mod_size_h = 60
@@ -138,10 +178,75 @@ class BaseInpainter:
138
  else:
139
  comp_frames[idx] = comp_frames[idx].astype(
140
  np.float32) * 0.5 + img.astype(np.float32) * 0.5
141
-
142
  inpainted_frames = np.stack(comp_frames, 0)
143
  return inpainted_frames.astype(np.uint8)
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  if __name__ == '__main__':
146
 
147
  frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg'))
@@ -179,4 +284,4 @@ if __name__ == '__main__':
179
  # save
180
  for ti, inpainted_frame in enumerate(inpainted_frames):
181
  frame = Image.fromarray(inpainted_frame).convert('RGB')
182
- frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))
9
  from tqdm import tqdm
10
  from inpainter.util.tensor_util import resize_frames, resize_masks
11
 
12
+ def read_image_from_split(videp_split_path):
13
  # if type:
14
+ image = np.asarray([np.asarray(Image.open(path)) for path in videp_split_path])
15
  # else:
16
  # image = cv2.cvtColor(cv2.imread("/tmp/{}/paintedimages/{}/{:08d}.png".format(username, video_state["video_name"], index+ ".png")), cv2.COLOR_BGR2RGB)
17
  return image
56
  break
57
  ref_index.append(i)
58
  return ref_index
59
+
60
+ def inpaint_efficient(self, frames, masks, num_tcb, num_tca, dilate_radius=15, ratio=1):
61
  """
62
+ Perform Inpainting for video subsets
63
  frames: numpy array, T, H, W, 3
64
  masks: numpy array, T, H, W
65
+ num_tcb: constant, number of temporal context before, frames
66
+ num_tca: constant, number of temporal context after, frames
67
  dilate_radius: radius when applying dilation on masks
68
  ratio: down-sample ratio
69
 
70
  Output:
71
  inpainted_frames: numpy array, T, H, W, 3
72
  """
 
 
 
 
 
73
  assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
74
  assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
75
+
76
+ # --------------------
77
+ # pre-processing
78
+ # --------------------
79
  masks = masks.copy()
80
  masks = np.clip(masks, 0, 1)
81
  kernel = cv2.getStructuringElement(2, (dilate_radius, dilate_radius))
82
  masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
 
83
  T, H, W = masks.shape
84
  masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
85
  # size: (w, h)
97
  frames = resize_frames(frames, tuple(size)) # T, H, W, 3
98
  # frames and binary_masks are numpy arrays
99
  h, w = frames.shape[1:3]
100
+ video_length = T - (num_tca + num_tcb) # real video length
 
101
  # convert to tensor
102
  imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
103
  masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
 
104
  imgs, masks = imgs.to(self.device), masks.to(self.device)
105
  comp_frames = [None] * video_length
106
+ tcb_imgs = None
107
+ tca_imgs = None
108
+ tcb_masks = None
109
+ tca_masks = None
110
+ # --------------------
111
+ # end of pre-processing
112
+ # --------------------
113
+
114
+ # separate tc frames/masks from imgs and masks
115
+ if num_tcb > 0:
116
+ tcb_imgs = imgs[:, :num_tcb]
117
+ tcb_masks = masks[:, :num_tcb]
118
+ tcb_binary = binary_masks[:num_tcb]
119
+ if num_tca > 0:
120
+ tca_imgs = imgs[:, -num_tca:]
121
+ tca_masks = masks[:, -num_tca:]
122
+ tca_binary = binary_masks[-num_tca:]
123
+ end_idx = -num_tca
124
+ else:
125
+ end_idx = T
126
+
127
+ imgs = imgs[:, num_tcb:end_idx]
128
+ masks = masks[:, num_tcb:end_idx]
129
+ binary_masks = binary_masks[num_tcb:end_idx] # only neighbor area are involved
130
+ frames = frames[num_tcb:end_idx] # only neighbor area are involved
131
 
132
  for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
133
  neighbor_ids = [
135
  min(video_length, f + self.neighbor_stride + 1))
136
  ]
137
  ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
138
+
139
+ # selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
140
+ # selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
141
+
142
+ selected_imgs = imgs[:, neighbor_ids]
143
+ selected_masks = masks[:, neighbor_ids]
144
+ # pad before
145
+ if tcb_imgs is not None:
146
+ selected_imgs = torch.concat([selected_imgs, tcb_imgs], dim=1)
147
+ selected_masks = torch.concat([selected_masks, tcb_masks], dim=1)
148
+ # integrate ref frames
149
+ selected_imgs = torch.concat([selected_imgs, imgs[:, ref_ids]], dim=1)
150
+ selected_masks = torch.concat([selected_masks, masks[:, ref_ids]], dim=1)
151
+ # pad after
152
+ if tca_imgs is not None:
153
+ selected_imgs = torch.concat([selected_imgs, tca_imgs], dim=1)
154
+ selected_masks = torch.concat([selected_masks, tca_masks], dim=1)
155
+
156
  with torch.no_grad():
157
  masked_imgs = selected_imgs * (1 - selected_masks)
158
  mod_size_h = 60
178
  else:
179
  comp_frames[idx] = comp_frames[idx].astype(
180
  np.float32) * 0.5 + img.astype(np.float32) * 0.5
181
+ torch.cuda.empty_cache()
182
  inpainted_frames = np.stack(comp_frames, 0)
183
  return inpainted_frames.astype(np.uint8)
184
 
185
+ def inpaint(self, frames_path, masks, dilate_radius=15, ratio=1):
186
+ """
187
+ Perform Inpainting for video subsets
188
+ frames: numpy array, T, H, W, 3
189
+ masks: numpy array, T, H, W
190
+ dilate_radius: radius when applying dilation on masks
191
+ ratio: down-sample ratio
192
+
193
+ Output:
194
+ inpainted_frames: numpy array, T, H, W, 3
195
+ """
196
+ # assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
197
+ assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
198
+
199
+ # set interval
200
+ interval = 45
201
+ context_range = 10 # for each split, consider its temporal context [-context_range] frames and [context_range] frames
202
+ # split frames into subsets
203
+ video_length = len(frames_path)
204
+ num_splits = video_length // interval
205
+ id_splits = [[i*interval, (i+1)*interval] for i in range(num_splits)] # id splits
206
+ # if remaining split > interval/2, add a new split, else, append to the last split
207
+ if video_length - id_splits[-1][-1] > interval / 2:
208
+ id_splits.append([num_splits*interval, video_length])
209
+ else:
210
+ id_splits[-1][-1] = video_length
211
+
212
+ # perform inpainting for each split
213
+ inpainted_splits = []
214
+ for id_split in id_splits:
215
+ video_split_path = frames_path[id_split[0]:id_split[1]]
216
+ video_split = read_image_from_split(video_split_path)
217
+ mask_split = masks[id_split[0]:id_split[1]]
218
+
219
+ # | id_before | ----- | id_split[0] | ----- | id_split[1] | ----- | id_after |
220
+ # add temporal context
221
+ id_before = max(0, id_split[0] - self.step * context_range)
222
+ try:
223
+ tcb_frames = np.stack([np.array(Image.open(frames_path[idb])) for idb in range(id_before, id_split[0]-self.step, self.step)], 0)
224
+ tcb_masks = np.stack([masks[idb] for idb in range(id_before, id_split[0]-self.step, self.step)], 0)
225
+ num_tcb = len(tcb_frames)
226
+ except:
227
+ num_tcb = 0
228
+ id_after = min(video_length, id_split[1] + self.step * context_range)
229
+ try:
230
+ tca_frames = np.stack([np.array(Image.open(frames_path[ida])) for ida in range(id_split[1]+self.step, id_after, self.step)], 0)
231
+ tca_masks = np.stack([masks[ida] for ida in range(id_split[1]+self.step, id_after, self.step)], 0)
232
+ num_tca = len(tca_frames)
233
+ except:
234
+ num_tca = 0
235
+
236
+ # concatenate temporal context frames/masks with input frames/masks (for parallel pre-processing)
237
+ if num_tcb > 0:
238
+ video_split = np.concatenate([tcb_frames, video_split], 0)
239
+ mask_split = np.concatenate([tcb_masks, mask_split], 0)
240
+ if num_tca > 0:
241
+ video_split = np.concatenate([video_split, tca_frames], 0)
242
+ mask_split = np.concatenate([mask_split, tca_masks], 0)
243
+
244
+ # inpaint each split
245
+ inpainted_splits.append(self.inpaint_efficient(video_split, mask_split, num_tcb, num_tca, dilate_radius, ratio))
246
+
247
+ inpainted_frames = np.concatenate(inpainted_splits, 0)
248
+ return inpainted_frames.astype(np.uint8)
249
+
250
  if __name__ == '__main__':
251
 
252
  frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg'))
284
  # save
285
  for ti, inpainted_frame in enumerate(inpainted_frames):
286
  frame = Image.fromarray(inpainted_frame).convert('RGB')
287
+ frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))
test_sample/test-sample13.mp4 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bf112202beb75ecf7d04b27758f1f3eedfc218dac5d5dad0b72a07dd2db0f423
3
- size 59659465
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54f24a0aaae482aff7ff3555256f60ad1931d478dc5694fb37624cac85479eee
3
+ size 2528426
test_sample/test-sample4.mp4 CHANGED
Binary files a/test_sample/test-sample4.mp4 and b/test_sample/test-sample4.mp4 differ
test_sample/test-sample8.mp4 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:01d255ef82222d950d2cfae904d82cc20c752577016f0325b21788fb9b458bb9
3
- size 11979994
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2414d24cc1ddfe1619c17e9876a7c3ed0f1f37da234c63c08af2cecbbb16c1ed
3
+ size 8714250