watchtowerss committed
Commit
4d1ebf3
1 Parent(s): 663e9a6

track-anything --version 1

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +3 -0
  2. LICENSE +21 -0
  3. README.md +47 -13
  4. XMem-s012.pth +3 -0
  5. app.py +362 -0
  6. app_save.py +381 -0
  7. app_test.py +23 -0
  8. assets/demo_version_1.MP4 +3 -0
  9. assets/inpainting.gif +3 -0
  10. assets/poster_demo_version_1.png +0 -0
  11. assets/qingming.mp4 +3 -0
  12. demo.py +87 -0
  13. images/groceries.jpg +0 -0
  14. images/mask_painter.png +0 -0
  15. images/painter_input_image.jpg +0 -0
  16. images/painter_input_mask.jpg +0 -0
  17. images/painter_output_image.png +0 -0
  18. images/painter_output_image__.png +0 -0
  19. images/point_painter.png +0 -0
  20. images/point_painter_1.png +0 -0
  21. images/point_painter_2.png +0 -0
  22. images/truck.jpg +0 -0
  23. images/truck_both.jpg +0 -0
  24. images/truck_mask.jpg +0 -0
  25. images/truck_point.jpg +0 -0
  26. inpainter/.DS_Store +0 -0
  27. inpainter/base_inpainter.py +160 -0
  28. inpainter/config/config.yaml +4 -0
  29. inpainter/model/e2fgvi.py +350 -0
  30. inpainter/model/e2fgvi_hq.py +350 -0
  31. inpainter/model/modules/feat_prop.py +149 -0
  32. inpainter/model/modules/flow_comp.py +450 -0
  33. inpainter/model/modules/spectral_norm.py +288 -0
  34. inpainter/model/modules/tfocal_transformer.py +536 -0
  35. inpainter/model/modules/tfocal_transformer_hq.py +565 -0
  36. inpainter/util/__init__.py +0 -0
  37. inpainter/util/tensor_util.py +24 -0
  38. requirements.txt +17 -0
  39. sam_vit_h_4b8939.pth +3 -0
  40. template.html +27 -0
  41. templates/index.html +50 -0
  42. text_server.py +72 -0
  43. tools/__init__.py +0 -0
  44. tools/base_segmenter.py +129 -0
  45. tools/interact_tools.py +265 -0
  46. tools/mask_painter.py +288 -0
  47. tools/painter.py +215 -0
  48. track_anything.py +93 -0
  49. tracker/.DS_Store +0 -0
  50. tracker/base_tracker.py +233 -0
.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/demo_version_1.MP4 filter=lfs diff=lfs merge=lfs -text
+ assets/inpainting.gif filter=lfs diff=lfs merge=lfs -text
+ assets/qingming.mp4 filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Mingqi Gao
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,47 @@
- ---
- title: Track Anything
- emoji: 🐠
- colorFrom: purple
- colorTo: indigo
- sdk: gradio
- sdk_version: 3.27.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Track-Anything
+
+ ***Track-Anything*** is a flexible and interactive tool for video object tracking and segmentation. It is built upon [Segment Anything](https://github.com/facebookresearch/segment-anything) and can specify anything to track and segment with user clicks only. During tracking, users can flexibly change the objects they want to track or correct the region of interest if any ambiguities arise. These characteristics make ***Track-Anything*** suitable for:
+ - Video object tracking and segmentation with shot changes.
+ - Data annotation for video object tracking and segmentation.
+ - Object-centric downstream video tasks, such as video inpainting and editing.
+
+ ## Demo
+
+ https://user-images.githubusercontent.com/28050374/232842703-8395af24-b13e-4b8e-aafb-e94b61e6c449.MP4
+
+ ### Multiple Object Tracking and Segmentation (with [XMem](https://github.com/hkchengrex/XMem))
+
+ https://user-images.githubusercontent.com/39208339/233035206-0a151004-6461-4deb-b782-d1dbfe691493.mp4
+
+ ### Video Object Tracking and Segmentation with Shot Changes (with [XMem](https://github.com/hkchengrex/XMem))
+
+ https://user-images.githubusercontent.com/30309970/232848349-f5e29e71-2ea4-4529-ac9a-94b9ca1e7055.mp4
+
+ ### Video Inpainting (with [E2FGVI](https://github.com/MCG-NKU/E2FGVI))
+
+ https://user-images.githubusercontent.com/28050374/232959816-07f2826f-d267-4dda-8ae5-a5132173b8f4.mp4
+
+ ## Get Started
+ #### Linux
+ ```bash
+ # Clone the repository:
+ git clone https://github.com/gaomingqi/Track-Anything.git
+ cd Track-Anything
+
+ # Install dependencies:
+ pip install -r requirements.txt
+
+ # Install dependencies for inpainting:
+ pip install -U openmim
+ mim install mmcv
+
+ # Install dependencies for editing:
+ pip install madgrad
+
+ # Run the Track-Anything gradio demo.
+ python app.py --device cuda:0 --sam_model_type vit_h --port 12212
+ ```
+
+ ## Acknowledgements
+
+ The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything), [XMem](https://github.com/hkchengrex/XMem), and [E2FGVI](https://github.com/MCG-NKU/E2FGVI). Thanks to the authors for their efforts.
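For context while reading the rest of this diff: the Gradio app added below simply wires UI events to a handful of calls on `TrackingAnything`. A minimal sketch of driving the same pipeline from a script, assuming the checkpoints have already been fetched into `./checkpoints` (as `download_checkpoint` in `app.py` does) and using the bundled `assets/qingming.mp4` with a made-up click coordinate:

```python
import cv2
import numpy as np
from track_anything import TrackingAnything, parse_augment

# Checkpoints placed where app.py's download_checkpoint() stores them (assumed local paths).
sam_ckpt = "./checkpoints/sam_vit_h_4b8939.pth"
xmem_ckpt = "./checkpoints/XMem-s012.pth"
model = TrackingAnything(sam_ckpt, xmem_ckpt, parse_augment())

# Read RGB frames, mirroring get_frames_from_video() in app.py.
cap = cv2.VideoCapture("assets/qingming.mp4")
frames = []
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

# One positive click (hypothetical coordinate) on the template frame.
model.samcontroler.sam_controler.set_image(frames[0])
mask, logit, painted = model.first_frame_click(
    image=frames[0],
    points=np.array([[320, 240]]),
    labels=np.array([1]),
    multimask="True",  # string value, exactly as passed in app.py
)

# Propagate the mask through the remaining frames with XMem.
masks, logits, painted_frames = model.generator(images=frames, template_mask=mask)
```

The UI handlers in `app.py` (`select_template`, `sam_refine`, `vos_tracking_video`) are thin wrappers around these same calls.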
XMem-s012.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:16205ad04bfc55b442bd4d7af894382e09868b35e10721c5afc09a24ea8d72d9
+ size 249026057
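The three `+` lines above are a Git LFS pointer rather than the weights themselves: the ~249 MB XMem checkpoint is stored in LFS and identified by the SHA-256 digest in `oid`. If the checkpoint is fetched separately (as `app.py` does, into `./checkpoints`), that digest can be used to sanity-check the download. A small sketch, assuming the local path and that the fetched release file is identical to the one committed here:

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Stream the file in 1 MB chunks so the ~249 MB checkpoint is never loaded at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# oid recorded in the LFS pointer above
EXPECTED = "16205ad04bfc55b442bd4d7af894382e09868b35e10721c5afc09a24ea8d72d9"
assert sha256_of("./checkpoints/XMem-s012.pth") == EXPECTED, "XMem checkpoint is corrupt or incomplete"
```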
app.py ADDED
@@ -0,0 +1,362 @@
1
+ import gradio as gr
2
+ from demo import automask_image_app, automask_video_app, sahi_autoseg_app
3
+ import argparse
4
+ import cv2
5
+ import time
6
+ from PIL import Image
7
+ import numpy as np
8
+ import os
9
+ import sys
10
+ sys.path.append(sys.path[0]+"/tracker")
11
+ sys.path.append(sys.path[0]+"/tracker/model")
12
+ from track_anything import TrackingAnything
13
+ from track_anything import parse_augment
14
+ import requests
15
+ import json
16
+ import torchvision
17
+ import torch
18
+ import concurrent.futures
19
+ import queue
20
+
21
+ # download checkpoints
22
+ def download_checkpoint(url, folder, filename):
23
+ os.makedirs(folder, exist_ok=True)
24
+ filepath = os.path.join(folder, filename)
25
+
26
+ if not os.path.exists(filepath):
27
+ print("downloading checkpoints ...")
28
+ response = requests.get(url, stream=True)
29
+ with open(filepath, "wb") as f:
30
+ for chunk in response.iter_content(chunk_size=8192):
31
+ if chunk:
32
+ f.write(chunk)
33
+
34
+ print("downloaded successfully!")
35
+
36
+ return filepath
37
+
38
+ # convert points input to prompt state
39
+ def get_prompt(click_state, click_input):
40
+ inputs = json.loads(click_input)
41
+ points = click_state[0]
42
+ labels = click_state[1]
43
+ for input in inputs:
44
+ points.append(input[:2])
45
+ labels.append(input[2])
46
+ click_state[0] = points
47
+ click_state[1] = labels
48
+ prompt = {
49
+ "prompt_type":["click"],
50
+ "input_point":click_state[0],
51
+ "input_label":click_state[1],
52
+ "multimask_output":"True",
53
+ }
54
+ return prompt
55
+
56
+ # extract frames from upload video
57
+ def get_frames_from_video(video_input, video_state):
58
+ """
59
+ Args:
60
+ video_path:str
61
+ timestamp:float64
62
+ Return
63
+ [[0:nearest_frame], [nearest_frame:], nearest_frame]
64
+ """
65
+ video_path = video_input
66
+ frames = []
67
+ try:
68
+ cap = cv2.VideoCapture(video_path)
69
+ fps = cap.get(cv2.CAP_PROP_FPS)
70
+ while cap.isOpened():
71
+ ret, frame = cap.read()
72
+ if ret == True:
73
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
74
+ else:
75
+ break
76
+ except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
77
+ print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
78
+
79
+ # initialize video_state
80
+ video_state = {
81
+ "video_name": os.path.split(video_path)[-1],
82
+ "origin_images": frames,
83
+ "painted_images": frames.copy(),
84
+ "masks": [None]*len(frames),
85
+ "logits": [None]*len(frames),
86
+ "select_frame_number": 0,
87
+ "fps": 30
88
+ }
89
+ return video_state, gr.update(visible=True, maximum=len(frames), value=1)
90
+
91
+ # get the select frame from gradio slider
92
+ def select_template(image_selection_slider, video_state):
93
+
94
+ # images = video_state[1]
95
+ image_selection_slider -= 1
96
+ video_state["select_frame_number"] = image_selection_slider
97
+
98
+ # once select a new template frame, set the image in sam
99
+
100
+ model.samcontroler.sam_controler.reset_image()
101
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
102
+
103
+
104
+ return video_state["painted_images"][image_selection_slider], video_state
105
+
106
+ # use sam to get the mask
107
+ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
108
+ """
109
+ Args:
110
+ template_frame: PIL.Image
111
+ point_prompt: flag for positive or negative button click
112
+ click_state: [[points], [labels]]
113
+ """
114
+ if point_prompt == "Positive":
115
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
116
+ interactive_state["positive_click_times"] += 1
117
+ else:
118
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
119
+ interactive_state["negative_click_times"] += 1
120
+
121
+ # prompt for sam model
122
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
123
+
124
+ mask, logit, painted_image = model.first_frame_click(
125
+ image=video_state["origin_images"][video_state["select_frame_number"]],
126
+ points=np.array(prompt["input_point"]),
127
+ labels=np.array(prompt["input_label"]),
128
+ multimask=prompt["multimask_output"],
129
+ )
130
+ video_state["masks"][video_state["select_frame_number"]] = mask
131
+ video_state["logits"][video_state["select_frame_number"]] = logit
132
+ video_state["painted_images"][video_state["select_frame_number"]] = painted_image
133
+
134
+ return painted_image, video_state, interactive_state
135
+
136
+ # tracking vos
137
+ def vos_tracking_video(video_state, interactive_state):
138
+ model.xmem.clear_memory()
139
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
140
+ template_mask = video_state["masks"][video_state["select_frame_number"]]
141
+ fps = video_state["fps"]
142
+ masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
143
+
144
+ video_state["masks"][video_state["select_frame_number"]:] = masks
145
+ video_state["logits"][video_state["select_frame_number"]:] = logits
146
+ video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
147
+
148
+ video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
149
+ interactive_state["inference_times"] += 1
150
+
151
+ print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"],
152
+ interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
153
+ interactive_state["positive_click_times"],
154
+ interactive_state["negative_click_times"]))
155
+
156
+ #### shanggao code for mask save
157
+ if interactive_state["mask_save"]:
158
+ if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
159
+ os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0]))
160
+ i = 0
161
+ print("save mask")
162
+ for mask in video_state["masks"]:
163
+ np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask)
164
+ i+=1
165
+ # save_mask(video_state["masks"], video_state["video_name"])
166
+ #### shanggao code for mask save
167
+ return video_output, video_state, interactive_state
168
+
169
+ # generate video after vos inference
170
+ def generate_video_from_frames(frames, output_path, fps=30):
171
+ """
172
+ Generates a video from a list of frames.
173
+
174
+ Args:
175
+ frames (list of numpy arrays): The frames to include in the video.
176
+ output_path (str): The path to save the generated video.
177
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
178
+ """
179
+ frames = torch.from_numpy(np.asarray(frames))
180
+ if not os.path.exists(os.path.dirname(output_path)):
181
+ os.makedirs(os.path.dirname(output_path))
182
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
183
+ return output_path
184
+
185
+ # check and download checkpoints if needed
186
+ SAM_checkpoint = "sam_vit_h_4b8939.pth"
187
+ sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
188
+ xmem_checkpoint = "XMem-s012.pth"
189
+ xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
190
+ folder ="./checkpoints"
191
+ SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
192
+ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
193
+
194
+ # args, defined in track_anything.py
195
+ args = parse_augment()
196
+ # args.port = 12212
197
+ # args.device = "cuda:4"
198
+ # args.mask_save = True
199
+
200
+ model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
201
+
202
+ with gr.Blocks() as iface:
203
+ """
204
+ state for
205
+ """
206
+ click_state = gr.State([[],[]])
207
+ interactive_state = gr.State({
208
+ "inference_times": 0,
209
+ "negative_click_times" : 0,
210
+ "positive_click_times": 0,
211
+ "mask_save": args.mask_save
212
+ })
213
+ video_state = gr.State(
214
+ {
215
+ "video_name": "",
216
+ "origin_images": None,
217
+ "painted_images": None,
218
+ "masks": None,
219
+ "logits": None,
220
+ "select_frame_number": 0,
221
+ "fps": 30
222
+ }
223
+ )
224
+
225
+ with gr.Row():
226
+
227
+ # for user video input
228
+ with gr.Column(scale=1.0):
229
+ video_input = gr.Video().style(height=360)
230
+
231
+
232
+
233
+ with gr.Row(scale=1):
234
+ # put the template frame under the radio button
235
+ with gr.Column(scale=0.5):
236
+ # extract frames
237
+ with gr.Column():
238
+ extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
239
+
240
+ # click point settings: negative or positive, mode continuous or single
241
+ with gr.Row():
242
+ with gr.Row(scale=0.5):
243
+ point_prompt = gr.Radio(
244
+ choices=["Positive", "Negative"],
245
+ value="Positive",
246
+ label="Point Prompt",
247
+ interactive=True)
248
+ click_mode = gr.Radio(
249
+ choices=["Continuous", "Single"],
250
+ value="Continuous",
251
+ label="Clicking Mode",
252
+ interactive=True)
253
+ with gr.Row(scale=0.5):
254
+ clear_button_clike = gr.Button(value="Clear Clicks", interactive=True).style(height=160)
255
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
256
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360)
257
+ image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", invisible=False)
258
+
259
+
260
+
261
+
262
+ with gr.Column(scale=0.5):
263
+ video_output = gr.Video().style(height=360)
264
+ tracking_video_predict_button = gr.Button(value="Tracking")
265
+
266
+ # first step: get the video information
267
+ extract_frames_button.click(
268
+ fn=get_frames_from_video,
269
+ inputs=[
270
+ video_input, video_state
271
+ ],
272
+ outputs=[video_state, image_selection_slider],
273
+ )
274
+
275
+ # second step: select images from slider
276
+ image_selection_slider.release(fn=select_template,
277
+ inputs=[image_selection_slider, video_state],
278
+ outputs=[template_frame, video_state], api_name="select_image")
279
+
280
+
281
+ template_frame.select(
282
+ fn=sam_refine,
283
+ inputs=[video_state, point_prompt, click_state, interactive_state],
284
+ outputs=[template_frame, video_state, interactive_state]
285
+ )
286
+
287
+ tracking_video_predict_button.click(
288
+ fn=vos_tracking_video,
289
+ inputs=[video_state, interactive_state],
290
+ outputs=[video_output, video_state, interactive_state]
291
+ )
292
+
293
+
294
+ # clear input
295
+ video_input.clear(
296
+ lambda: (
297
+ {
298
+ "origin_images": None,
299
+ "painted_images": None,
300
+ "masks": None,
301
+ "logits": None,
302
+ "select_frame_number": 0,
303
+ "fps": 30
304
+ },
305
+ {
306
+ "inference_times": 0,
307
+ "negative_click_times" : 0,
308
+ "positive_click_times": 0,
309
+ "mask_save": args.mask_save
310
+ },
311
+ [[],[]]
312
+ ),
313
+ [],
314
+ [
315
+ video_state,
316
+ interactive_state,
317
+ click_state,
318
+ ],
319
+ queue=False,
320
+ show_progress=False
321
+ )
322
+ clear_button_image.click(
323
+ lambda: (
324
+ {
325
+ "origin_images": None,
326
+ "painted_images": None,
327
+ "masks": None,
328
+ "logits": None,
329
+ "select_frame_number": 0,
330
+ "fps": 30
331
+ },
332
+ {
333
+ "inference_times": 0,
334
+ "negative_click_times" : 0,
335
+ "positive_click_times": 0,
336
+ "mask_save": args.mask_save
337
+ },
338
+ [[],[]]
339
+ ),
340
+ [],
341
+ [
342
+ video_state,
343
+ interactive_state,
344
+ click_state,
345
+ ],
346
+
347
+ queue=False,
348
+ show_progress=False
349
+
350
+ )
351
+ clear_button_clike.click(
352
+ lambda: ([[],[]]),
353
+ [],
354
+ [click_state],
355
+ queue=False,
356
+ show_progress=False
357
+ )
358
+ iface.queue(concurrency_count=1)
359
+ iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
360
+
361
+
362
+
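A quick worked example of the click bookkeeping in `app.py` above: `sam_refine` encodes each click as a JSON string `[[x, y, label]]`, and `get_prompt` accumulates those clicks in `click_state` before building the SAM prompt. The snippet below reproduces that helper (lightly condensed) with made-up coordinates to show what ends up in the prompt:

```python
import json

click_state = [[], []]                     # [[points], [labels]], as initialised in app.py

def get_prompt(click_state, click_input):  # condensed copy of the helper in app.py
    points, labels = click_state
    for p in json.loads(click_input):
        points.append(p[:2])
        labels.append(p[2])
    return {
        "prompt_type": ["click"],
        "input_point": points,
        "input_label": labels,
        "multimask_output": "True",
    }

get_prompt(click_state, "[[100,50,1]]")             # positive click at (100, 50)
prompt = get_prompt(click_state, "[[200,120,0]]")   # negative click at (200, 120)
print(prompt["input_point"])   # [[100, 50], [200, 120]]
print(prompt["input_label"])   # [1, 0]
```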
app_save.py ADDED
@@ -0,0 +1,381 @@
1
+ import gradio as gr
2
+ from demo import automask_image_app, automask_video_app, sahi_autoseg_app
3
+ import argparse
4
+ import cv2
5
+ import time
6
+ from PIL import Image
7
+ import numpy as np
8
+ import os
9
+ import sys
10
+ sys.path.append(sys.path[0]+"/tracker")
11
+ sys.path.append(sys.path[0]+"/tracker/model")
12
+ from track_anything import TrackingAnything
13
+ from track_anything import parse_augment
14
+ import requests
15
+ import json
16
+ import torchvision
17
+ import torch
18
+ import concurrent.futures
19
+ import queue
20
+
21
+ def download_checkpoint(url, folder, filename):
22
+ os.makedirs(folder, exist_ok=True)
23
+ filepath = os.path.join(folder, filename)
24
+
25
+ if not os.path.exists(filepath):
26
+ print("downloading checkpoints ...")
27
+ response = requests.get(url, stream=True)
28
+ with open(filepath, "wb") as f:
29
+ for chunk in response.iter_content(chunk_size=8192):
30
+ if chunk:
31
+ f.write(chunk)
32
+
33
+ print("downloaded successfully!")
34
+
35
+ return filepath
36
+
37
+ def pause_video(play_state):
38
+ print("user pause_video")
39
+ play_state.append(time.time())
40
+ return play_state
41
+
42
+ def play_video(play_state):
43
+ print("user play_video")
44
+ play_state.append(time.time())
45
+ return play_state
46
+
47
+ # convert points input to prompt state
48
+ def get_prompt(click_state, click_input):
49
+ inputs = json.loads(click_input)
50
+ points = click_state[0]
51
+ labels = click_state[1]
52
+ for input in inputs:
53
+ points.append(input[:2])
54
+ labels.append(input[2])
55
+ click_state[0] = points
56
+ click_state[1] = labels
57
+ prompt = {
58
+ "prompt_type":["click"],
59
+ "input_point":click_state[0],
60
+ "input_label":click_state[1],
61
+ "multimask_output":"True",
62
+ }
63
+ return prompt
64
+
65
+ def get_frames_from_video(video_input, play_state):
66
+ """
67
+ Args:
68
+ video_path:str
69
+ timestamp:float64
70
+ Return
71
+ [[0:nearest_frame], [nearest_frame:], nearest_frame]
72
+ """
73
+ video_path = video_input
74
+ # video_name = video_path.split('/')[-1]
75
+
76
+ try:
77
+ timestamp = play_state[1] - play_state[0]
78
+ except:
79
+ timestamp = 0
80
+ frames = []
81
+ try:
82
+ cap = cv2.VideoCapture(video_path)
83
+ fps = cap.get(cv2.CAP_PROP_FPS)
84
+ while cap.isOpened():
85
+ ret, frame = cap.read()
86
+ if ret == True:
87
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
88
+ else:
89
+ break
90
+ except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
91
+ print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
92
+
93
+ # for index, frame in enumerate(frames):
94
+ # frames[index] = np.asarray(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
95
+
96
+ key_frame_index = int(timestamp * fps)
97
+ nearest_frame = frames[key_frame_index]
98
+ frames_split = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
99
+ # output_path='./seperate.mp4'
100
+ # torchvision.io.write_video(output_path, frames[1], fps=fps, video_codec="libx264")
101
+
102
+ # set image in sam when select the template frame
103
+ model.samcontroler.sam_controler.set_image(nearest_frame)
104
+ return frames_split, nearest_frame, nearest_frame, fps
105
+
106
+ def generate_video_from_frames(frames, output_path, fps=30):
107
+ """
108
+ Generates a video from a list of frames.
109
+
110
+ Args:
111
+ frames (list of numpy arrays): The frames to include in the video.
112
+ output_path (str): The path to save the generated video.
113
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
114
+ """
115
+ # height, width, layers = frames[0].shape
116
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
117
+ # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
118
+
119
+ # for frame in frames:
120
+ # video.write(frame)
121
+
122
+ # video.release()
123
+ frames = torch.from_numpy(np.asarray(frames))
124
+ output_path='./output.mp4'
125
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
126
+ return output_path
127
+
128
+ def model_reset():
129
+ model.xmem.clear_memory()
130
+ return None
131
+
132
+ def sam_refine(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
133
+ """
134
+ Args:
135
+ template_frame: PIL.Image
136
+ point_prompt: flag for positive or negative button click
137
+ click_state: [[points], [labels]]
138
+ """
139
+ if point_prompt == "Positive":
140
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
141
+ else:
142
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
143
+
144
+ # prompt for sam model
145
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
146
+
147
+ # default value
148
+ # points = np.array([[evt.index[0],evt.index[1]]])
149
+ # labels= np.array([1])
150
+ if len(logit)==0:
151
+ logit = None
152
+
153
+ mask, logit, painted_image = model.first_frame_click(
154
+ image=origin_frame,
155
+ points=np.array(prompt["input_point"]),
156
+ labels=np.array(prompt["input_label"]),
157
+ multimask=prompt["multimask_output"],
158
+ )
159
+ return painted_image, click_state, logit, mask
160
+
161
+
162
+
163
+ def vos_tracking_video(video_state, template_mask,fps,video_input):
164
+
165
+ masks, logits, painted_images = model.generator(images=video_state[1], template_mask=template_mask)
166
+ video_output = generate_video_from_frames(painted_images, output_path="./output.mp4", fps=fps)
167
+ # image_selection_slider = gr.Slider(minimum=1, maximum=len(video_state[1]), value=1, label="Image Selection", interactive=True)
168
+ video_name = video_input.split('/')[-1].split('.')[0]
169
+ result_path = os.path.join('/hhd3/gaoshang/Track-Anything/results/'+video_name)
170
+ if not os.path.exists(result_path):
171
+ os.makedirs(result_path)
172
+ i=0
173
+ for mask in masks:
174
+ np.save(os.path.join(result_path,'{:05}.npy'.format(i)), mask)
175
+ i+=1
176
+ return video_output, painted_images, masks, logits
177
+
178
+ def vos_tracking_image(image_selection_slider, painted_images):
179
+
180
+ # images = video_state[1]
181
+ percentage = image_selection_slider / 100
182
+ select_frame_num = int(percentage * len(painted_images))
183
+ return painted_images[select_frame_num], select_frame_num
184
+
185
+ def interactive_correction(video_state, point_prompt, click_state, select_correction_frame, evt: gr.SelectData):
186
+ """
187
+ Args:
188
+ template_frame: PIL.Image
189
+ point_prompt: flag for positive or negative button click
190
+ click_state: [[points], [labels]]
191
+ """
192
+ refine_image = video_state[1][select_correction_frame]
193
+ if point_prompt == "Positive":
194
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
195
+ else:
196
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
197
+
198
+ # prompt for sam model
199
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
200
+ model.samcontroler.seg_again(refine_image)
201
+ corrected_mask, corrected_logit, corrected_painted_image = model.first_frame_click(
202
+ image=refine_image,
203
+ points=np.array(prompt["input_point"]),
204
+ labels=np.array(prompt["input_label"]),
205
+ multimask=prompt["multimask_output"],
206
+ )
207
+ return corrected_painted_image, [corrected_mask, corrected_logit, corrected_painted_image]
208
+
209
+ def correct_track(video_state, select_correction_frame, corrected_state, masks, logits, painted_images, fps, video_input):
210
+ model.xmem.clear_memory()
211
+ # inference the following images
212
+ following_images = video_state[1][select_correction_frame:]
213
+ corrected_masks, corrected_logits, corrected_painted_images = model.generator(images=following_images, template_mask=corrected_state[0])
214
+ masks = masks[:select_correction_frame] + corrected_masks
215
+ logits = logits[:select_correction_frame] + corrected_logits
216
+ painted_images = painted_images[:select_correction_frame] + corrected_painted_images
217
+ video_output = generate_video_from_frames(painted_images, output_path="./output.mp4", fps=fps)
218
+
219
+ video_name = video_input.split('/')[-1].split('.')[0]
220
+ result_path = os.path.join('/hhd3/gaoshang/Track-Anything/results/'+video_name)
221
+ if not os.path.exists(result_path):
222
+ os.makedirs(result_path)
223
+ i=0
224
+ for mask in masks:
225
+ np.save(os.path.join(result_path,'{:05}.npy'.format(i)), mask)
226
+ i+=1
227
+ return video_output, painted_images, logits, masks
228
+
229
+ # check and download checkpoints if needed
230
+ SAM_checkpoint = "sam_vit_h_4b8939.pth"
231
+ sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
232
+ xmem_checkpoint = "XMem-s012.pth"
233
+ xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
234
+ folder ="./checkpoints"
235
+ SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
236
+ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
237
+
238
+ # args, defined in track_anything.py
239
+ args = parse_augment()
240
+ args.port = 12207
241
+ args.device = "cuda:5"
242
+
243
+ model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
244
+
245
+ with gr.Blocks() as iface:
246
+ """
247
+ state for
248
+ """
249
+ state = gr.State([])
250
+ play_state = gr.State([])
251
+ video_state = gr.State([[],[],[]])
252
+ click_state = gr.State([[],[]])
253
+ logits = gr.State([])
254
+ masks = gr.State([])
255
+ painted_images = gr.State([])
256
+ origin_image = gr.State(None)
257
+ template_mask = gr.State(None)
258
+ select_correction_frame = gr.State(None)
259
+ corrected_state = gr.State([[],[],[]])
260
+ fps = gr.State([])
261
+ # video_name = gr.State([])
262
+ # queue value for image refresh, origin image, mask, logits, painted image
263
+
264
+
265
+
266
+ with gr.Row():
267
+
268
+ # for user video input
269
+ with gr.Column(scale=1.0):
270
+ video_input = gr.Video().style(height=720)
271
+
272
+ # listen to the user action for play and pause input video
273
+ video_input.play(fn=play_video, inputs=play_state, outputs=play_state, scroll_to_output=True, show_progress=True)
274
+ video_input.pause(fn=pause_video, inputs=play_state, outputs=play_state)
275
+
276
+
277
+ with gr.Row(scale=1):
278
+ # put the template frame under the radio button
279
+ with gr.Column(scale=0.5):
280
+ # click point settings: negative or positive, mode continuous or single
281
+ with gr.Row():
282
+ with gr.Row(scale=0.5):
283
+ point_prompt = gr.Radio(
284
+ choices=["Positive", "Negative"],
285
+ value="Positive",
286
+ label="Point Prompt",
287
+ interactive=True)
288
+ click_mode = gr.Radio(
289
+ choices=["Continuous", "Single"],
290
+ value="Continuous",
291
+ label="Clicking Mode",
292
+ interactive=True)
293
+ with gr.Row(scale=0.5):
294
+ clear_button_clike = gr.Button(value="Clear Clicks", interactive=True).style(height=160)
295
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
296
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360)
297
+ with gr.Column():
298
+ template_select_button = gr.Button(value="Template select", interactive=True, variant="primary")
299
+
300
+
301
+
302
+ with gr.Column(scale=0.5):
303
+
304
+
305
+ # for intermedia result check and correction
306
+ # intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
307
+ video_output = gr.Video().style(height=360)
308
+ tracking_video_predict_button = gr.Button(value="Tracking")
309
+
310
+ image_output = gr.Image(type="pil", interactive=True, elem_id="image_output").style(height=360)
311
+ image_selection_slider = gr.Slider(minimum=0, maximum=100, step=0.1, value=0, label="Image Selection", interactive=True)
312
+ correct_track_button = gr.Button(value="Interactive Correction")
313
+
314
+ template_frame.select(
315
+ fn=sam_refine,
316
+ inputs=[
317
+ origin_image, point_prompt, click_state, logits
318
+ ],
319
+ outputs=[
320
+ template_frame, click_state, logits, template_mask
321
+ ]
322
+ )
323
+
324
+ template_select_button.click(
325
+ fn=get_frames_from_video,
326
+ inputs=[
327
+ video_input,
328
+ play_state
329
+ ],
330
+ # outputs=[video_state, template_frame, origin_image, fps, video_name],
331
+ outputs=[video_state, template_frame, origin_image, fps],
332
+ )
333
+
334
+ tracking_video_predict_button.click(
335
+ fn=vos_tracking_video,
336
+ inputs=[video_state, template_mask, fps, video_input],
337
+ outputs=[video_output, painted_images, masks, logits]
338
+ )
339
+ image_selection_slider.release(fn=vos_tracking_image,
340
+ inputs=[image_selection_slider, painted_images], outputs=[image_output, select_correction_frame], api_name="select_image")
341
+ # correction
342
+ image_output.select(
343
+ fn=interactive_correction,
344
+ inputs=[video_state, point_prompt, click_state, select_correction_frame],
345
+ outputs=[image_output, corrected_state]
346
+ )
347
+ correct_track_button.click(
348
+ fn=correct_track,
349
+ inputs=[video_state, select_correction_frame, corrected_state, masks, logits, painted_images, fps,video_input],
350
+ outputs=[video_output, painted_images, logits, masks ]
351
+ )
352
+
353
+
354
+
355
+ # clear input
356
+ video_input.clear(
357
+ lambda: ([], [], [[], [], []],
358
+ None, "", "", "", "", "", "", "", [[],[]],
359
+ None),
360
+ [],
361
+ [ state, play_state, video_state,
362
+ template_frame, video_output, image_output, origin_image, template_mask, painted_images, masks, logits, click_state,
363
+ select_correction_frame],
364
+ queue=False,
365
+ show_progress=False
366
+ )
367
+ clear_button_image.click(
368
+ fn=model_reset
369
+ )
370
+ clear_button_clike.click(
371
+ lambda: ([[],[]]),
372
+ [],
373
+ [click_state],
374
+ queue=False,
375
+ show_progress=False
376
+ )
377
+ iface.queue(concurrency_count=1)
378
+ iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
379
+
380
+
381
+
app_test.py ADDED
@@ -0,0 +1,23 @@
+ import gradio as gr
+
+ def update_iframe(slider_value):
+     return f'''
+     <script>
+     window.addEventListener('message', function(event) {{
+         if (event.data.sliderValue !== undefined) {{
+             var iframe = document.getElementById("text_iframe");
+             iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
+         }}
+     }}, false);
+     </script>
+     <iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
+     '''
+
+ iface = gr.Interface(
+     fn=update_iframe,
+     inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
+     outputs=gr.outputs.HTML(),
+     allow_flagging=False,
+ )
+
+ iface.launch(server_name='0.0.0.0', server_port=12212)
assets/demo_version_1.MP4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b61b54bc6eb0d0f7416f95aa3cd6a48d850ca7473022ec1aff48310911b0233
+ size 27053146
assets/inpainting.gif ADDED

Git LFS Details

  • SHA256: 5e99bd697bccaed7a0dded7f00855f222031b7dcefd8f64f22f374fcdab390d2
  • Pointer size: 133 Bytes
  • Size of remote file: 22.2 MB
assets/poster_demo_version_1.png ADDED
assets/qingming.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:58b34bbce0bd0a18ab5fc5450d4046e1cfc6bd55c508046695545819d8fc46dc
+ size 4483842
demo.py ADDED
@@ -0,0 +1,87 @@
+ from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor, SahiAutoSegmentation, sahi_sliced_predict
+
+ # For image
+
+ def automask_image_app(image_path, model_type, points_per_side, points_per_batch, min_area):
+     SegAutoMaskPredictor().image_predict(
+         source=image_path,
+         model_type=model_type,  # vit_l, vit_h, vit_b
+         points_per_side=points_per_side,
+         points_per_batch=points_per_batch,
+         min_area=min_area,
+         output_path="output.png",
+         show=False,
+         save=True,
+     )
+     return "output.png"
+
+
+ # For video
+
+ def automask_video_app(video_path, model_type, points_per_side, points_per_batch, min_area):
+     SegAutoMaskPredictor().video_predict(
+         source=video_path,
+         model_type=model_type,  # vit_l, vit_h, vit_b
+         points_per_side=points_per_side,
+         points_per_batch=points_per_batch,
+         min_area=min_area,
+         output_path="output.mp4",
+     )
+     return "output.mp4"
+
+
+ # For manual box and point selection
+
+ def manual_app(image_path, model_type, input_point, input_label, input_box, multimask_output, random_color):
+     SegManualMaskPredictor().image_predict(
+         source=image_path,
+         model_type=model_type,  # vit_l, vit_h, vit_b
+         input_point=input_point,
+         input_label=input_label,
+         input_box=input_box,
+         multimask_output=multimask_output,
+         random_color=random_color,
+         output_path="output.png",
+         show=False,
+         save=True,
+     )
+     return "output.png"
+
+
+ # For sahi sliced prediction
+
+ def sahi_autoseg_app(
+     image_path,
+     sam_model_type,
+     detection_model_type,
+     detection_model_path,
+     conf_th,
+     image_size,
+     slice_height,
+     slice_width,
+     overlap_height_ratio,
+     overlap_width_ratio,
+ ):
+     boxes = sahi_sliced_predict(
+         image_path=image_path,
+         detection_model_type=detection_model_type,  # yolov8, detectron2, mmdetection, torchvision
+         detection_model_path=detection_model_path,
+         conf_th=conf_th,
+         image_size=image_size,
+         slice_height=slice_height,
+         slice_width=slice_width,
+         overlap_height_ratio=overlap_height_ratio,
+         overlap_width_ratio=overlap_width_ratio,
+     )
+
+     SahiAutoSegmentation().predict(
+         source=image_path,
+         model_type=sam_model_type,
+         input_box=boxes,
+         multimask_output=False,
+         random_color=False,
+         show=False,
+         save=True,
+     )
+
+     return "output.png"
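The wrappers above are thin shims over the `metaseg` package, and each writes its result to a fixed output file. As a hedged usage sketch, the automatic-mask wrapper can be pointed at one of the sample images shipped in this commit; the numeric parameter values below are illustrative, not taken from the repository:

```python
from demo import automask_image_app

# Run SAM's automatic mask generation on a bundled sample image.
result = automask_image_app(
    image_path="images/truck.jpg",  # shipped in this commit
    model_type="vit_h",             # vit_l / vit_h / vit_b, per the comment in demo.py
    points_per_side=16,             # illustrative values
    points_per_batch=64,
    min_area=0,
)
print(result)  # "output.png"
```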
images/groceries.jpg ADDED
images/mask_painter.png ADDED
images/painter_input_image.jpg ADDED
images/painter_input_mask.jpg ADDED
images/painter_output_image.png ADDED
images/painter_output_image__.png ADDED
images/point_painter.png ADDED
images/point_painter_1.png ADDED
images/point_painter_2.png ADDED
images/truck.jpg ADDED
images/truck_both.jpg ADDED
images/truck_mask.jpg ADDED
images/truck_point.jpg ADDED
inpainter/.DS_Store ADDED
Binary file (6.15 kB).
inpainter/base_inpainter.py ADDED
@@ -0,0 +1,160 @@
1
+ import os
2
+ import glob
3
+ from PIL import Image
4
+
5
+ import torch
6
+ import yaml
7
+ import cv2
8
+ import importlib
9
+ import numpy as np
10
+ from util.tensor_util import resize_frames, resize_masks
11
+
12
+
13
+ class BaseInpainter:
14
+ def __init__(self, E2FGVI_checkpoint, device) -> None:
15
+ """
16
+ E2FGVI_checkpoint: checkpoint of inpainter (version hq, with multi-resolution support)
17
+ """
18
+ net = importlib.import_module('model.e2fgvi_hq')
19
+ self.model = net.InpaintGenerator().to(device)
20
+ self.model.load_state_dict(torch.load(E2FGVI_checkpoint, map_location=device))
21
+ self.model.eval()
22
+ self.device = device
23
+ # load configurations
24
+ with open("inpainter/config/config.yaml", 'r') as stream:
25
+ config = yaml.safe_load(stream)
26
+ self.neighbor_stride = config['neighbor_stride']
27
+ self.num_ref = config['num_ref']
28
+ self.step = config['step']
29
+
30
+ # sample reference frames from the whole video
31
+ def get_ref_index(self, f, neighbor_ids, length):
32
+ ref_index = []
33
+ if self.num_ref == -1:
34
+ for i in range(0, length, self.step):
35
+ if i not in neighbor_ids:
36
+ ref_index.append(i)
37
+ else:
38
+ start_idx = max(0, f - self.step * (self.num_ref // 2))
39
+ end_idx = min(length, f + self.step * (self.num_ref // 2))
40
+ for i in range(start_idx, end_idx + 1, self.step):
41
+ if i not in neighbor_ids:
42
+ if len(ref_index) > self.num_ref:
43
+ break
44
+ ref_index.append(i)
45
+ return ref_index
46
+
47
+ def inpaint(self, frames, masks, dilate_radius=15, ratio=1):
48
+ """
49
+ frames: numpy array, T, H, W, 3
50
+ masks: numpy array, T, H, W
51
+ dilate_radius: radius when applying dilation on masks
52
+ ratio: down-sample ratio
53
+
54
+ Output:
55
+ inpainted_frames: numpy array, T, H, W, 3
56
+ """
57
+ assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
58
+ assert 0 < ratio <= 1, 'ratio must be in (0, 1]'
59
+ masks = masks.copy()
60
+ masks = np.clip(masks, 0, 1)
61
+ kernel = cv2.getStructuringElement(2, (dilate_radius, dilate_radius))
62
+ masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
63
+
64
+ T, H, W = masks.shape
65
+ # size: (w, h)
66
+ if ratio == 1:
67
+ size = None
68
+ else:
69
+ size = (int(W*ratio), int(H*ratio))
70
+
71
+ masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
72
+ binary_masks = resize_masks(masks, size)
73
+ frames = resize_frames(frames, size) # T, H, W, 3
74
+ # frames and binary_masks are numpy arrays
75
+
76
+ h, w = frames.shape[1:3]
77
+ video_length = T
78
+
79
+ # convert to tensor
80
+ imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
81
+ masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
82
+
83
+ imgs, masks = imgs.to(self.device), masks.to(self.device)
84
+ comp_frames = [None] * video_length
85
+
86
+ for f in range(0, video_length, self.neighbor_stride):
87
+ neighbor_ids = [
88
+ i for i in range(max(0, f - self.neighbor_stride),
89
+ min(video_length, f + self.neighbor_stride + 1))
90
+ ]
91
+ ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
92
+ selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
93
+ selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
94
+ with torch.no_grad():
95
+ masked_imgs = selected_imgs * (1 - selected_masks)
96
+ mod_size_h = 60
97
+ mod_size_w = 108
98
+ h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
99
+ w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
100
+ masked_imgs = torch.cat(
101
+ [masked_imgs, torch.flip(masked_imgs, [3])],
102
+ 3)[:, :, :, :h + h_pad, :]
103
+ masked_imgs = torch.cat(
104
+ [masked_imgs, torch.flip(masked_imgs, [4])],
105
+ 4)[:, :, :, :, :w + w_pad]
106
+ pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids))
107
+ pred_imgs = pred_imgs[:, :, :h, :w]
108
+ pred_imgs = (pred_imgs + 1) / 2
109
+ pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
110
+ for i in range(len(neighbor_ids)):
111
+ idx = neighbor_ids[i]
112
+ img = pred_imgs[i].astype(np.uint8) * binary_masks[idx] + frames[idx] * (
113
+ 1 - binary_masks[idx])
114
+ if comp_frames[idx] is None:
115
+ comp_frames[idx] = img
116
+ else:
117
+ comp_frames[idx] = comp_frames[idx].astype(
118
+ np.float32) * 0.5 + img.astype(np.float32) * 0.5
119
+
120
+ inpainted_frames = np.stack(comp_frames, 0)
121
+ return inpainted_frames.astype(np.uint8)
122
+
123
+ if __name__ == '__main__':
124
+
125
+ frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg'))
126
+ frame_path.sort()
127
+ mask_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/Annotations/480p/parkour', "*.png"))
128
+ mask_path.sort()
129
+ save_path = '/ssd1/gaomingqi/results/inpainting/parkour'
130
+
131
+ if not os.path.exists(save_path):
132
+ os.mkdir(save_path)
133
+
134
+ frames = []
135
+ masks = []
136
+ for fid, mid in zip(frame_path, mask_path):
137
+ frames.append(Image.open(fid).convert('RGB'))
138
+ masks.append(Image.open(mid).convert('P'))
139
+
140
+ frames = np.stack(frames, 0)
141
+ masks = np.stack(masks, 0)
142
+
143
+ # ----------------------------------------------
144
+ # how to use
145
+ # ----------------------------------------------
146
+ # 1/3: set checkpoint and device
147
+ checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth'
148
+ device = 'cuda:6'
149
+ # 2/3: initialise inpainter
150
+ base_inpainter = BaseInpainter(checkpoint, device)
151
+ # 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
152
+ # ratio: (0, 1], ratio for down sample, default value is 1
153
+ inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=1) # numpy array, T, H, W, 3
154
+ # ----------------------------------------------
155
+ # end
156
+ # ----------------------------------------------
157
+ # save
158
+ for ti, inpainted_frame in enumerate(inpainted_frames):
159
+ frame = Image.fromarray(inpainted_frame).convert('RGB')
160
+ frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))
inpainter/config/config.yaml ADDED
@@ -0,0 +1,4 @@
+ # config info for E2FGVI
+ neighbor_stride: 5
+ num_ref: -1
+ step: 10
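These three values drive `BaseInpainter` above: `neighbor_stride` sets the local window that is inpainted jointly, while `num_ref: -1` combined with `step: 10` makes `get_ref_index` take every 10th frame outside that window as a global reference. A small illustration of that selection for a hypothetical 80-frame clip centred on frame 40:

```python
# Reference-frame selection as in BaseInpainter.get_ref_index for the config above
# (num_ref: -1, step: 10) and a local window built with neighbor_stride: 5.
step = 10
length = 80                                                     # hypothetical clip length
f = 40                                                          # centre of the local window
neighbor_ids = list(range(max(0, f - 5), min(length, f + 6)))   # frames 35..45

ref_index = [i for i in range(0, length, step) if i not in neighbor_ids]
print(ref_index)   # [0, 10, 20, 30, 50, 60, 70]
```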
inpainter/model/e2fgvi.py ADDED
@@ -0,0 +1,350 @@
1
+ ''' Towards An End-to-End Framework for Video Inpainting
2
+ '''
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from model.modules.flow_comp import SPyNet
9
+ from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
10
+ from model.modules.tfocal_transformer import TemporalFocalTransformerBlock, SoftSplit, SoftComp
11
+ from model.modules.spectral_norm import spectral_norm as _spectral_norm
12
+
13
+
14
+ class BaseNetwork(nn.Module):
15
+ def __init__(self):
16
+ super(BaseNetwork, self).__init__()
17
+
18
+ def print_network(self):
19
+ if isinstance(self, list):
20
+ self = self[0]
21
+ num_params = 0
22
+ for param in self.parameters():
23
+ num_params += param.numel()
24
+ print(
25
+ 'Network [%s] was created. Total number of parameters: %.1f million. '
26
+ 'To see the architecture, do print(network).' %
27
+ (type(self).__name__, num_params / 1000000))
28
+
29
+ def init_weights(self, init_type='normal', gain=0.02):
30
+ '''
31
+ initialize network's weights
32
+ init_type: normal | xavier | kaiming | orthogonal
33
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
34
+ '''
35
+ def init_func(m):
36
+ classname = m.__class__.__name__
37
+ if classname.find('InstanceNorm2d') != -1:
38
+ if hasattr(m, 'weight') and m.weight is not None:
39
+ nn.init.constant_(m.weight.data, 1.0)
40
+ if hasattr(m, 'bias') and m.bias is not None:
41
+ nn.init.constant_(m.bias.data, 0.0)
42
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1
43
+ or classname.find('Linear') != -1):
44
+ if init_type == 'normal':
45
+ nn.init.normal_(m.weight.data, 0.0, gain)
46
+ elif init_type == 'xavier':
47
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
48
+ elif init_type == 'xavier_uniform':
49
+ nn.init.xavier_uniform_(m.weight.data, gain=1.0)
50
+ elif init_type == 'kaiming':
51
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
52
+ elif init_type == 'orthogonal':
53
+ nn.init.orthogonal_(m.weight.data, gain=gain)
54
+ elif init_type == 'none': # uses pytorch's default init method
55
+ m.reset_parameters()
56
+ else:
57
+ raise NotImplementedError(
58
+ 'initialization method [%s] is not implemented' %
59
+ init_type)
60
+ if hasattr(m, 'bias') and m.bias is not None:
61
+ nn.init.constant_(m.bias.data, 0.0)
62
+
63
+ self.apply(init_func)
64
+
65
+ # propagate to children
66
+ for m in self.children():
67
+ if hasattr(m, 'init_weights'):
68
+ m.init_weights(init_type, gain)
69
+
70
+
71
+ class Encoder(nn.Module):
72
+ def __init__(self):
73
+ super(Encoder, self).__init__()
74
+ self.group = [1, 2, 4, 8, 1]
75
+ self.layers = nn.ModuleList([
76
+ nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
77
+ nn.LeakyReLU(0.2, inplace=True),
78
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
79
+ nn.LeakyReLU(0.2, inplace=True),
80
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
81
+ nn.LeakyReLU(0.2, inplace=True),
82
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
83
+ nn.LeakyReLU(0.2, inplace=True),
84
+ nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
85
+ nn.LeakyReLU(0.2, inplace=True),
86
+ nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
87
+ nn.LeakyReLU(0.2, inplace=True),
88
+ nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
89
+ nn.LeakyReLU(0.2, inplace=True),
90
+ nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
91
+ nn.LeakyReLU(0.2, inplace=True),
92
+ nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
93
+ nn.LeakyReLU(0.2, inplace=True)
94
+ ])
95
+
96
+ def forward(self, x):
97
+ bt, c, h, w = x.size()
98
+ h, w = h // 4, w // 4
99
+ out = x
100
+ for i, layer in enumerate(self.layers):
101
+ if i == 8:
102
+ x0 = out
103
+ if i > 8 and i % 2 == 0:
104
+ g = self.group[(i - 8) // 2]
105
+ x = x0.view(bt, g, -1, h, w)
106
+ o = out.view(bt, g, -1, h, w)
107
+ out = torch.cat([x, o], 2).view(bt, -1, h, w)
108
+ out = layer(out)
109
+ return out
110
+
111
+
112
+ class deconv(nn.Module):
113
+ def __init__(self,
114
+ input_channel,
115
+ output_channel,
116
+ kernel_size=3,
117
+ padding=0):
118
+ super().__init__()
119
+ self.conv = nn.Conv2d(input_channel,
120
+ output_channel,
121
+ kernel_size=kernel_size,
122
+ stride=1,
123
+ padding=padding)
124
+
125
+ def forward(self, x):
126
+ x = F.interpolate(x,
127
+ scale_factor=2,
128
+ mode='bilinear',
129
+ align_corners=True)
130
+ return self.conv(x)
131
+
132
+
133
+ class InpaintGenerator(BaseNetwork):
134
+ def __init__(self, init_weights=True):
135
+ super(InpaintGenerator, self).__init__()
136
+ channel = 256
137
+ hidden = 512
138
+
139
+ # encoder
140
+ self.encoder = Encoder()
141
+
142
+ # decoder
143
+ self.decoder = nn.Sequential(
144
+ deconv(channel // 2, 128, kernel_size=3, padding=1),
145
+ nn.LeakyReLU(0.2, inplace=True),
146
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
147
+ nn.LeakyReLU(0.2, inplace=True),
148
+ deconv(64, 64, kernel_size=3, padding=1),
149
+ nn.LeakyReLU(0.2, inplace=True),
150
+ nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
151
+
152
+ # feature propagation module
153
+ self.feat_prop_module = BidirectionalPropagation(channel // 2)
154
+
155
+ # soft split and soft composition
156
+ kernel_size = (7, 7)
157
+ padding = (3, 3)
158
+ stride = (3, 3)
159
+ output_size = (60, 108)
160
+ t2t_params = {
161
+ 'kernel_size': kernel_size,
162
+ 'stride': stride,
163
+ 'padding': padding,
164
+ 'output_size': output_size
165
+ }
166
+ self.ss = SoftSplit(channel // 2,
167
+ hidden,
168
+ kernel_size,
169
+ stride,
170
+ padding,
171
+ t2t_param=t2t_params)
172
+ self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size,
173
+ stride, padding)
174
+
175
+ n_vecs = 1
176
+ for i, d in enumerate(kernel_size):
177
+ n_vecs *= int((output_size[i] + 2 * padding[i] -
178
+ (d - 1) - 1) / stride[i] + 1)
179
+
180
+ blocks = []
181
+ depths = 8
182
+ num_heads = [4] * depths
183
+ window_size = [(5, 9)] * depths
184
+ focal_windows = [(5, 9)] * depths
185
+ focal_levels = [2] * depths
186
+ pool_method = "fc"
187
+
188
+ for i in range(depths):
189
+ blocks.append(
190
+ TemporalFocalTransformerBlock(dim=hidden,
191
+ num_heads=num_heads[i],
192
+ window_size=window_size[i],
193
+ focal_level=focal_levels[i],
194
+ focal_window=focal_windows[i],
195
+ n_vecs=n_vecs,
196
+ t2t_params=t2t_params,
197
+ pool_method=pool_method))
198
+ self.transformer = nn.Sequential(*blocks)
199
+
200
+ if init_weights:
201
+ self.init_weights()
202
+ # Need to initialize the weights of MSDeformAttn specifically
203
+ for m in self.modules():
204
+ if isinstance(m, SecondOrderDeformableAlignment):
205
+ m.init_offset()
206
+
207
+ # flow completion network
208
+ self.update_spynet = SPyNet()
209
+
210
+ def forward_bidirect_flow(self, masked_local_frames):
211
+ b, l_t, c, h, w = masked_local_frames.size()
212
+
213
+ # compute forward and backward flows of masked frames
214
+ masked_local_frames = F.interpolate(masked_local_frames.view(
215
+ -1, c, h, w),
216
+ scale_factor=1 / 4,
217
+ mode='bilinear',
218
+ align_corners=True,
219
+ recompute_scale_factor=True)
220
+ masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
221
+ w // 4)
222
+ mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
223
+ -1, c, h // 4, w // 4)
224
+ mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
225
+ -1, c, h // 4, w // 4)
226
+ pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
227
+ pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
228
+
229
+ pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
230
+ w // 4)
231
+ pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
232
+ w // 4)
233
+
234
+ return pred_flows_forward, pred_flows_backward
235
+
236
+ def forward(self, masked_frames, num_local_frames):
237
+ l_t = num_local_frames
238
+ b, t, ori_c, ori_h, ori_w = masked_frames.size()
239
+
240
+ # normalization before feeding into the flow completion module
241
+ masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
242
+ pred_flows = self.forward_bidirect_flow(masked_local_frames)
243
+
244
+ # extracting features and performing the feature propagation on local features
245
+ enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
246
+ _, c, h, w = enc_feat.size()
247
+ local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
248
+ ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
249
+ local_feat = self.feat_prop_module(local_feat, pred_flows[0],
250
+ pred_flows[1])
251
+ enc_feat = torch.cat((local_feat, ref_feat), dim=1)
252
+
253
+ # content hallucination through stacking multiple temporal focal transformer blocks
254
+ trans_feat = self.ss(enc_feat.view(-1, c, h, w), b)
255
+ trans_feat = self.transformer(trans_feat)
256
+ trans_feat = self.sc(trans_feat, t)
257
+ trans_feat = trans_feat.view(b, t, -1, h, w)
258
+ enc_feat = enc_feat + trans_feat
259
+
260
+ # decode frames from features
261
+ output = self.decoder(enc_feat.view(b * t, c, h, w))
262
+ output = torch.tanh(output)
263
+ return output, pred_flows
264
+
265
+
266
+ # ######################################################################
267
+ # Discriminator for Temporal Patch GAN
268
+ # ######################################################################
269
+
270
+
271
+ class Discriminator(BaseNetwork):
272
+ def __init__(self,
273
+ in_channels=3,
274
+ use_sigmoid=False,
275
+ use_spectral_norm=True,
276
+ init_weights=True):
277
+ super(Discriminator, self).__init__()
278
+ self.use_sigmoid = use_sigmoid
279
+ nf = 32
280
+
281
+ self.conv = nn.Sequential(
282
+ spectral_norm(
283
+ nn.Conv3d(in_channels=in_channels,
284
+ out_channels=nf * 1,
285
+ kernel_size=(3, 5, 5),
286
+ stride=(1, 2, 2),
287
+ padding=1,
288
+ bias=not use_spectral_norm), use_spectral_norm),
289
+ # nn.InstanceNorm2d(64, track_running_stats=False),
290
+ nn.LeakyReLU(0.2, inplace=True),
291
+ spectral_norm(
292
+ nn.Conv3d(nf * 1,
293
+ nf * 2,
294
+ kernel_size=(3, 5, 5),
295
+ stride=(1, 2, 2),
296
+ padding=(1, 2, 2),
297
+ bias=not use_spectral_norm), use_spectral_norm),
298
+ # nn.InstanceNorm2d(128, track_running_stats=False),
299
+ nn.LeakyReLU(0.2, inplace=True),
300
+ spectral_norm(
301
+ nn.Conv3d(nf * 2,
302
+ nf * 4,
303
+ kernel_size=(3, 5, 5),
304
+ stride=(1, 2, 2),
305
+ padding=(1, 2, 2),
306
+ bias=not use_spectral_norm), use_spectral_norm),
307
+ # nn.InstanceNorm2d(256, track_running_stats=False),
308
+ nn.LeakyReLU(0.2, inplace=True),
309
+ spectral_norm(
310
+ nn.Conv3d(nf * 4,
311
+ nf * 4,
312
+ kernel_size=(3, 5, 5),
313
+ stride=(1, 2, 2),
314
+ padding=(1, 2, 2),
315
+ bias=not use_spectral_norm), use_spectral_norm),
316
+ # nn.InstanceNorm2d(256, track_running_stats=False),
317
+ nn.LeakyReLU(0.2, inplace=True),
318
+ spectral_norm(
319
+ nn.Conv3d(nf * 4,
320
+ nf * 4,
321
+ kernel_size=(3, 5, 5),
322
+ stride=(1, 2, 2),
323
+ padding=(1, 2, 2),
324
+ bias=not use_spectral_norm), use_spectral_norm),
325
+ # nn.InstanceNorm2d(256, track_running_stats=False),
326
+ nn.LeakyReLU(0.2, inplace=True),
327
+ nn.Conv3d(nf * 4,
328
+ nf * 4,
329
+ kernel_size=(3, 5, 5),
330
+ stride=(1, 2, 2),
331
+ padding=(1, 2, 2)))
332
+
333
+ if init_weights:
334
+ self.init_weights()
335
+
336
+ def forward(self, xs):
337
+ # T, C, H, W = xs.shape (old)
338
+ # B, T, C, H, W (new)
339
+ xs_t = torch.transpose(xs, 1, 2)
340
+ feat = self.conv(xs_t)
341
+ if self.use_sigmoid:
342
+ feat = torch.sigmoid(feat)
343
+ out = torch.transpose(feat, 1, 2) # B, T, C, H, W
344
+ return out
345
+
346
+
347
+ def spectral_norm(module, mode=True):
348
+ if mode:
349
+ return _spectral_norm(module)
350
+ return module
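For orientation, here is a minimal usage sketch of the InpaintGenerator from e2fgvi.py above. It is not part of the commit: the import path, the 240x432 frame size (implied by the fixed fold output_size of (60, 108) and the encoder's 4x downsampling), and the 5-frame local window are illustrative assumptions, and constructing the model triggers a download of the pretrained SPyNet weights.

import torch
from model.e2fgvi import InpaintGenerator  # assumed import path inside this repo

model = InpaintGenerator().eval()               # also loads pretrained SPyNet weights
masked_frames = torch.randn(1, 8, 3, 240, 432)  # (b, t, c, h, w), values in [-1, 1]
num_local_frames = 5                            # the first 5 frames form the local window
with torch.no_grad():
    pred_imgs, pred_flows = model(masked_frames, num_local_frames)
print(pred_imgs.shape)      # (b * t, 3, 240, 432), tanh output in [-1, 1]
print(pred_flows[0].shape)  # (b, 4, 2, 60, 108): forward flows at 1/4 resolution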
inpainter/model/e2fgvi_hq.py ADDED
@@ -0,0 +1,350 @@
1
+ ''' Towards An End-to-End Framework for Video Inpainting
2
+ '''
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from model.modules.flow_comp import SPyNet
9
+ from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
10
+ from model.modules.tfocal_transformer_hq import TemporalFocalTransformerBlock, SoftSplit, SoftComp
11
+ from model.modules.spectral_norm import spectral_norm as _spectral_norm
12
+
13
+
14
+ class BaseNetwork(nn.Module):
15
+ def __init__(self):
16
+ super(BaseNetwork, self).__init__()
17
+
18
+ def print_network(self):
19
+ if isinstance(self, list):
20
+ self = self[0]
21
+ num_params = 0
22
+ for param in self.parameters():
23
+ num_params += param.numel()
24
+ print(
25
+ 'Network [%s] was created. Total number of parameters: %.1f million. '
26
+ 'To see the architecture, do print(network).' %
27
+ (type(self).__name__, num_params / 1000000))
28
+
29
+ def init_weights(self, init_type='normal', gain=0.02):
30
+ '''
31
+ initialize network's weights
32
+ init_type: normal | xavier | kaiming | orthogonal
33
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
34
+ '''
35
+ def init_func(m):
36
+ classname = m.__class__.__name__
37
+ if classname.find('InstanceNorm2d') != -1:
38
+ if hasattr(m, 'weight') and m.weight is not None:
39
+ nn.init.constant_(m.weight.data, 1.0)
40
+ if hasattr(m, 'bias') and m.bias is not None:
41
+ nn.init.constant_(m.bias.data, 0.0)
42
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1
43
+ or classname.find('Linear') != -1):
44
+ if init_type == 'normal':
45
+ nn.init.normal_(m.weight.data, 0.0, gain)
46
+ elif init_type == 'xavier':
47
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
48
+ elif init_type == 'xavier_uniform':
49
+ nn.init.xavier_uniform_(m.weight.data, gain=1.0)
50
+ elif init_type == 'kaiming':
51
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
52
+ elif init_type == 'orthogonal':
53
+ nn.init.orthogonal_(m.weight.data, gain=gain)
54
+ elif init_type == 'none': # uses pytorch's default init method
55
+ m.reset_parameters()
56
+ else:
57
+ raise NotImplementedError(
58
+ 'initialization method [%s] is not implemented' %
59
+ init_type)
60
+ if hasattr(m, 'bias') and m.bias is not None:
61
+ nn.init.constant_(m.bias.data, 0.0)
62
+
63
+ self.apply(init_func)
64
+
65
+ # propagate to children
66
+ for m in self.children():
67
+ if hasattr(m, 'init_weights'):
68
+ m.init_weights(init_type, gain)
69
+
70
+
71
+ class Encoder(nn.Module):
72
+ def __init__(self):
73
+ super(Encoder, self).__init__()
74
+ self.group = [1, 2, 4, 8, 1]
75
+ self.layers = nn.ModuleList([
76
+ nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
77
+ nn.LeakyReLU(0.2, inplace=True),
78
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
79
+ nn.LeakyReLU(0.2, inplace=True),
80
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
81
+ nn.LeakyReLU(0.2, inplace=True),
82
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
83
+ nn.LeakyReLU(0.2, inplace=True),
84
+ nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
85
+ nn.LeakyReLU(0.2, inplace=True),
86
+ nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
87
+ nn.LeakyReLU(0.2, inplace=True),
88
+ nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
89
+ nn.LeakyReLU(0.2, inplace=True),
90
+ nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
91
+ nn.LeakyReLU(0.2, inplace=True),
92
+ nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
93
+ nn.LeakyReLU(0.2, inplace=True)
94
+ ])
95
+
96
+ def forward(self, x):
97
+ bt, c, _, _ = x.size()
98
+ # h, w = h//4, w//4
99
+ out = x
100
+ for i, layer in enumerate(self.layers):
101
+ if i == 8:
102
+ x0 = out
103
+ _, _, h, w = x0.size()
104
+ if i > 8 and i % 2 == 0:
105
+ g = self.group[(i - 8) // 2]
106
+ x = x0.view(bt, g, -1, h, w)
107
+ o = out.view(bt, g, -1, h, w)
108
+ out = torch.cat([x, o], 2).view(bt, -1, h, w)
109
+ out = layer(out)
110
+ return out
111
+
112
+
113
+ class deconv(nn.Module):
114
+ def __init__(self,
115
+ input_channel,
116
+ output_channel,
117
+ kernel_size=3,
118
+ padding=0):
119
+ super().__init__()
120
+ self.conv = nn.Conv2d(input_channel,
121
+ output_channel,
122
+ kernel_size=kernel_size,
123
+ stride=1,
124
+ padding=padding)
125
+
126
+ def forward(self, x):
127
+ x = F.interpolate(x,
128
+ scale_factor=2,
129
+ mode='bilinear',
130
+ align_corners=True)
131
+ return self.conv(x)
132
+
133
+
134
+ class InpaintGenerator(BaseNetwork):
135
+ def __init__(self, init_weights=True):
136
+ super(InpaintGenerator, self).__init__()
137
+ channel = 256
138
+ hidden = 512
139
+
140
+ # encoder
141
+ self.encoder = Encoder()
142
+
143
+ # decoder
144
+ self.decoder = nn.Sequential(
145
+ deconv(channel // 2, 128, kernel_size=3, padding=1),
146
+ nn.LeakyReLU(0.2, inplace=True),
147
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
148
+ nn.LeakyReLU(0.2, inplace=True),
149
+ deconv(64, 64, kernel_size=3, padding=1),
150
+ nn.LeakyReLU(0.2, inplace=True),
151
+ nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
152
+
153
+ # feature propagation module
154
+ self.feat_prop_module = BidirectionalPropagation(channel // 2)
155
+
156
+ # soft split and soft composition
157
+ kernel_size = (7, 7)
158
+ padding = (3, 3)
159
+ stride = (3, 3)
160
+ output_size = (60, 108)
161
+ t2t_params = {
162
+ 'kernel_size': kernel_size,
163
+ 'stride': stride,
164
+ 'padding': padding
165
+ }
166
+ self.ss = SoftSplit(channel // 2,
167
+ hidden,
168
+ kernel_size,
169
+ stride,
170
+ padding,
171
+ t2t_param=t2t_params)
172
+ self.sc = SoftComp(channel // 2, hidden, kernel_size, stride, padding)
173
+
174
+ n_vecs = 1
175
+ for i, d in enumerate(kernel_size):
176
+ n_vecs *= int((output_size[i] + 2 * padding[i] -
177
+ (d - 1) - 1) / stride[i] + 1)
178
+
179
+ blocks = []
180
+ depths = 8
181
+ num_heads = [4] * depths
182
+ window_size = [(5, 9)] * depths
183
+ focal_windows = [(5, 9)] * depths
184
+ focal_levels = [2] * depths
185
+ pool_method = "fc"
186
+
187
+ for i in range(depths):
188
+ blocks.append(
189
+ TemporalFocalTransformerBlock(dim=hidden,
190
+ num_heads=num_heads[i],
191
+ window_size=window_size[i],
192
+ focal_level=focal_levels[i],
193
+ focal_window=focal_windows[i],
194
+ n_vecs=n_vecs,
195
+ t2t_params=t2t_params,
196
+ pool_method=pool_method))
197
+ self.transformer = nn.Sequential(*blocks)
198
+
199
+ if init_weights:
200
+ self.init_weights()
201
+ # Need to initialize the weights of MSDeformAttn specifically
202
+ for m in self.modules():
203
+ if isinstance(m, SecondOrderDeformableAlignment):
204
+ m.init_offset()
205
+
206
+ # flow completion network
207
+ self.update_spynet = SPyNet()
208
+
209
+ def forward_bidirect_flow(self, masked_local_frames):
210
+ b, l_t, c, h, w = masked_local_frames.size()
211
+
212
+ # compute forward and backward flows of masked frames
213
+ masked_local_frames = F.interpolate(masked_local_frames.view(
214
+ -1, c, h, w),
215
+ scale_factor=1 / 4,
216
+ mode='bilinear',
217
+ align_corners=True,
218
+ recompute_scale_factor=True)
219
+ masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
220
+ w // 4)
221
+ mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
222
+ -1, c, h // 4, w // 4)
223
+ mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
224
+ -1, c, h // 4, w // 4)
225
+ pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
226
+ pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
227
+
228
+ pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
229
+ w // 4)
230
+ pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
231
+ w // 4)
232
+
233
+ return pred_flows_forward, pred_flows_backward
234
+
235
+ def forward(self, masked_frames, num_local_frames):
236
+ l_t = num_local_frames
237
+ b, t, ori_c, ori_h, ori_w = masked_frames.size()
238
+
239
+ # normalization before feeding into the flow completion module
240
+ masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
241
+ pred_flows = self.forward_bidirect_flow(masked_local_frames)
242
+
243
+ # extracting features and performing the feature propagation on local features
244
+ enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
245
+ _, c, h, w = enc_feat.size()
246
+ fold_output_size = (h, w)
247
+ local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
248
+ ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
249
+ local_feat = self.feat_prop_module(local_feat, pred_flows[0],
250
+ pred_flows[1])
251
+ enc_feat = torch.cat((local_feat, ref_feat), dim=1)
252
+
253
+ # content hallucination through stacking multiple temporal focal transformer blocks
254
+ trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_output_size)
255
+ trans_feat = self.transformer([trans_feat, fold_output_size])
256
+ trans_feat = self.sc(trans_feat[0], t, fold_output_size)
257
+ trans_feat = trans_feat.view(b, t, -1, h, w)
258
+ enc_feat = enc_feat + trans_feat
259
+
260
+ # decode frames from features
261
+ output = self.decoder(enc_feat.view(b * t, c, h, w))
262
+ output = torch.tanh(output)
263
+ return output, pred_flows
264
+
265
+
266
+ # ######################################################################
267
+ # Discriminator for Temporal Patch GAN
268
+ # ######################################################################
269
+
270
+
271
+ class Discriminator(BaseNetwork):
272
+ def __init__(self,
273
+ in_channels=3,
274
+ use_sigmoid=False,
275
+ use_spectral_norm=True,
276
+ init_weights=True):
277
+ super(Discriminator, self).__init__()
278
+ self.use_sigmoid = use_sigmoid
279
+ nf = 32
280
+
281
+ self.conv = nn.Sequential(
282
+ spectral_norm(
283
+ nn.Conv3d(in_channels=in_channels,
284
+ out_channels=nf * 1,
285
+ kernel_size=(3, 5, 5),
286
+ stride=(1, 2, 2),
287
+ padding=1,
288
+ bias=not use_spectral_norm), use_spectral_norm),
289
+ # nn.InstanceNorm2d(64, track_running_stats=False),
290
+ nn.LeakyReLU(0.2, inplace=True),
291
+ spectral_norm(
292
+ nn.Conv3d(nf * 1,
293
+ nf * 2,
294
+ kernel_size=(3, 5, 5),
295
+ stride=(1, 2, 2),
296
+ padding=(1, 2, 2),
297
+ bias=not use_spectral_norm), use_spectral_norm),
298
+ # nn.InstanceNorm2d(128, track_running_stats=False),
299
+ nn.LeakyReLU(0.2, inplace=True),
300
+ spectral_norm(
301
+ nn.Conv3d(nf * 2,
302
+ nf * 4,
303
+ kernel_size=(3, 5, 5),
304
+ stride=(1, 2, 2),
305
+ padding=(1, 2, 2),
306
+ bias=not use_spectral_norm), use_spectral_norm),
307
+ # nn.InstanceNorm2d(256, track_running_stats=False),
308
+ nn.LeakyReLU(0.2, inplace=True),
309
+ spectral_norm(
310
+ nn.Conv3d(nf * 4,
311
+ nf * 4,
312
+ kernel_size=(3, 5, 5),
313
+ stride=(1, 2, 2),
314
+ padding=(1, 2, 2),
315
+ bias=not use_spectral_norm), use_spectral_norm),
316
+ # nn.InstanceNorm2d(256, track_running_stats=False),
317
+ nn.LeakyReLU(0.2, inplace=True),
318
+ spectral_norm(
319
+ nn.Conv3d(nf * 4,
320
+ nf * 4,
321
+ kernel_size=(3, 5, 5),
322
+ stride=(1, 2, 2),
323
+ padding=(1, 2, 2),
324
+ bias=not use_spectral_norm), use_spectral_norm),
325
+ # nn.InstanceNorm2d(256, track_running_stats=False),
326
+ nn.LeakyReLU(0.2, inplace=True),
327
+ nn.Conv3d(nf * 4,
328
+ nf * 4,
329
+ kernel_size=(3, 5, 5),
330
+ stride=(1, 2, 2),
331
+ padding=(1, 2, 2)))
332
+
333
+ if init_weights:
334
+ self.init_weights()
335
+
336
+ def forward(self, xs):
337
+ # T, C, H, W = xs.shape (old)
338
+ # B, T, C, H, W (new)
339
+ xs_t = torch.transpose(xs, 1, 2)
340
+ feat = self.conv(xs_t)
341
+ if self.use_sigmoid:
342
+ feat = torch.sigmoid(feat)
343
+ out = torch.transpose(feat, 1, 2) # B, T, C, H, W
344
+ return out
345
+
346
+
347
+ def spectral_norm(module, mode=True):
348
+ if mode:
349
+ return _spectral_norm(module)
350
+ return module
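The e2fgvi_hq.py variant above derives the fold output size from the encoder feature map at run time (fold_output_size), while the Discriminator defined at the end of the same file scores videos patch-wise with 3D convolutions. A hedged shape sketch of that discriminator follows; the import path and tensor sizes are illustrative assumptions only.

import torch
from model.e2fgvi_hq import Discriminator  # assumed import path

netD = Discriminator(in_channels=3, use_sigmoid=True)
video = torch.randn(2, 6, 3, 240, 432)  # (B, T, C, H, W): real or completed frames
scores = netD(video)                    # transposed internally to (B, C, T, H, W) for Conv3d
print(scores.shape)                     # (2, 6, 128, H', W'): per-patch real/fake scores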
inpainter/model/modules/feat_prop.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
8
+ from mmengine.model import constant_init
9
+
10
+ from model.modules.flow_comp import flow_warp
11
+
12
+
13
+ class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
14
+ """Second-order deformable alignment module."""
15
+ def __init__(self, *args, **kwargs):
16
+ self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
17
+
18
+ super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
19
+
20
+ self.conv_offset = nn.Sequential(
21
+ nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
22
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
23
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
24
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
25
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
26
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
27
+ nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
28
+ )
29
+
30
+ self.init_offset()
31
+
32
+ def init_offset(self):
33
+ constant_init(self.conv_offset[-1], val=0, bias=0)
34
+
35
+ def forward(self, x, extra_feat, flow_1, flow_2):
36
+ extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
37
+ out = self.conv_offset(extra_feat)
38
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
39
+
40
+ # offset
41
+ offset = self.max_residue_magnitude * torch.tanh(
42
+ torch.cat((o1, o2), dim=1))
43
+ offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
44
+ offset_1 = offset_1 + flow_1.flip(1).repeat(1,
45
+ offset_1.size(1) // 2, 1,
46
+ 1)
47
+ offset_2 = offset_2 + flow_2.flip(1).repeat(1,
48
+ offset_2.size(1) // 2, 1,
49
+ 1)
50
+ offset = torch.cat([offset_1, offset_2], dim=1)
51
+
52
+ # mask
53
+ mask = torch.sigmoid(mask)
54
+
55
+ return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
56
+ self.stride, self.padding,
57
+ self.dilation, self.groups,
58
+ self.deform_groups)
59
+
60
+
61
+ class BidirectionalPropagation(nn.Module):
62
+ def __init__(self, channel):
63
+ super(BidirectionalPropagation, self).__init__()
64
+ modules = ['backward_', 'forward_']
65
+ self.deform_align = nn.ModuleDict()
66
+ self.backbone = nn.ModuleDict()
67
+ self.channel = channel
68
+
69
+ for i, module in enumerate(modules):
70
+ self.deform_align[module] = SecondOrderDeformableAlignment(
71
+ 2 * channel, channel, 3, padding=1, deform_groups=16)
72
+
73
+ self.backbone[module] = nn.Sequential(
74
+ nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
75
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
76
+ nn.Conv2d(channel, channel, 3, 1, 1),
77
+ )
78
+
79
+ self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)
80
+
81
+ def forward(self, x, flows_backward, flows_forward):
82
+ """
83
+ x shape : [b, t, c, h, w]
84
+ return [b, t, c, h, w]
85
+ """
86
+ b, t, c, h, w = x.shape
87
+ feats = {}
88
+ feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]
89
+
90
+ for module_name in ['backward_', 'forward_']:
91
+
92
+ feats[module_name] = []
93
+
94
+ frame_idx = range(0, t)
95
+ flow_idx = range(-1, t - 1)
96
+ mapping_idx = list(range(0, len(feats['spatial'])))
97
+ mapping_idx += mapping_idx[::-1]
98
+
99
+ if 'backward' in module_name:
100
+ frame_idx = frame_idx[::-1]
101
+ flows = flows_backward
102
+ else:
103
+ flows = flows_forward
104
+
105
+ feat_prop = x.new_zeros(b, self.channel, h, w)
106
+ for i, idx in enumerate(frame_idx):
107
+ feat_current = feats['spatial'][mapping_idx[idx]]
108
+
109
+ if i > 0:
110
+ flow_n1 = flows[:, flow_idx[i], :, :, :]
111
+ cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
112
+
113
+ # initialize second-order features
114
+ feat_n2 = torch.zeros_like(feat_prop)
115
+ flow_n2 = torch.zeros_like(flow_n1)
116
+ cond_n2 = torch.zeros_like(cond_n1)
117
+ if i > 1:
118
+ feat_n2 = feats[module_name][-2]
119
+ flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
120
+ flow_n2 = flow_n1 + flow_warp(
121
+ flow_n2, flow_n1.permute(0, 2, 3, 1))
122
+ cond_n2 = flow_warp(feat_n2,
123
+ flow_n2.permute(0, 2, 3, 1))
124
+
125
+ cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
126
+ feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
127
+ feat_prop = self.deform_align[module_name](feat_prop, cond,
128
+ flow_n1,
129
+ flow_n2)
130
+
131
+ feat = [feat_current] + [
132
+ feats[k][idx]
133
+ for k in feats if k not in ['spatial', module_name]
134
+ ] + [feat_prop]
135
+
136
+ feat = torch.cat(feat, dim=1)
137
+ feat_prop = feat_prop + self.backbone[module_name](feat)
138
+ feats[module_name].append(feat_prop)
139
+
140
+ if 'backward' in module_name:
141
+ feats[module_name] = feats[module_name][::-1]
142
+
143
+ outputs = []
144
+ for i in range(0, t):
145
+ align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
146
+ align_feats = torch.cat(align_feats, dim=1)
147
+ outputs.append(self.fusion(align_feats))
148
+
149
+ return torch.stack(outputs, dim=1) + x
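A hedged shape sketch of BidirectionalPropagation from feat_prop.py above. It assumes an mmcv installation with the ModulatedDeformConv2d op available (a CUDA build, with tensors moved to the GPU, may be required) and mirrors how InpaintGenerator calls the module, passing the two flow directions positionally.

import torch
from model.modules.feat_prop import BidirectionalPropagation  # assumed import path

prop = BidirectionalPropagation(channel=128)
feats = torch.randn(1, 5, 128, 60, 108)  # (b, t, c, h, w) local encoder features
flows_a = torch.randn(1, 4, 2, 60, 108)  # t - 1 flows in one direction
flows_b = torch.randn(1, 4, 2, 60, 108)  # t - 1 flows in the other direction
out = prop(feats, flows_a, flows_b)      # residual output, same shape as feats
print(out.shape)                         # torch.Size([1, 5, 128, 60, 108])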
inpainter/model/modules/flow_comp.py ADDED
@@ -0,0 +1,450 @@
1
+ import numpy as np
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch
6
+
7
+ from mmcv.cnn import ConvModule
8
+ from mmengine.runner import load_checkpoint
9
+
10
+
11
+ class FlowCompletionLoss(nn.Module):
12
+ """Flow completion loss"""
13
+ def __init__(self):
14
+ super().__init__()
15
+ self.fix_spynet = SPyNet()
16
+ for p in self.fix_spynet.parameters():
17
+ p.requires_grad = False
18
+
19
+ self.l1_criterion = nn.L1Loss()
20
+
21
+ def forward(self, pred_flows, gt_local_frames):
22
+ b, l_t, c, h, w = gt_local_frames.size()
23
+
24
+ with torch.no_grad():
25
+ # compute gt forward and backward flows
26
+ gt_local_frames = F.interpolate(gt_local_frames.view(-1, c, h, w),
27
+ scale_factor=1 / 4,
28
+ mode='bilinear',
29
+ align_corners=True,
30
+ recompute_scale_factor=True)
31
+ gt_local_frames = gt_local_frames.view(b, l_t, c, h // 4, w // 4)
32
+ gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(
33
+ -1, c, h // 4, w // 4)
34
+ gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(
35
+ -1, c, h // 4, w // 4)
36
+ gt_flows_forward = self.fix_spynet(gtlf_1, gtlf_2)
37
+ gt_flows_backward = self.fix_spynet(gtlf_2, gtlf_1)
38
+
39
+ # calculate loss for flow completion
40
+ forward_flow_loss = self.l1_criterion(
41
+ pred_flows[0].view(-1, 2, h // 4, w // 4), gt_flows_forward)
42
+ backward_flow_loss = self.l1_criterion(
43
+ pred_flows[1].view(-1, 2, h // 4, w // 4), gt_flows_backward)
44
+ flow_loss = forward_flow_loss + backward_flow_loss
45
+
46
+ return flow_loss
47
+
48
+
49
+ class SPyNet(nn.Module):
50
+ """SPyNet network structure.
51
+ The difference to the SPyNet in [tof.py] is that
52
+ 1. more SPyNetBasicModule is used in this version, and
53
+ 2. no batch normalization is used in this version.
54
+ Paper:
55
+ Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
56
+ Args:
57
+ pretrained (str): path for pre-trained SPyNet. Default: None.
58
+ """
59
+ def __init__(
60
+ self,
61
+ use_pretrain=True,
62
+ pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth'
63
+ ):
64
+ super().__init__()
65
+
66
+ self.basic_module = nn.ModuleList(
67
+ [SPyNetBasicModule() for _ in range(6)])
68
+
69
+ if use_pretrain:
70
+ if isinstance(pretrained, str):
71
+ print("load pretrained SPyNet...")
72
+ load_checkpoint(self, pretrained, strict=True)
73
+ elif pretrained is not None:
74
+ raise TypeError('[pretrained] should be str or None, '
75
+ f'but got {type(pretrained)}.')
76
+
77
+ self.register_buffer(
78
+ 'mean',
79
+ torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
80
+ self.register_buffer(
81
+ 'std',
82
+ torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
83
+
84
+ def compute_flow(self, ref, supp):
85
+ """Compute flow from ref to supp.
86
+ Note that in this function, the images are already resized to a
87
+ multiple of 32.
88
+ Args:
89
+ ref (Tensor): Reference image with shape of (n, 3, h, w).
90
+ supp (Tensor): Supporting image with shape of (n, 3, h, w).
91
+ Returns:
92
+ Tensor: Estimated optical flow: (n, 2, h, w).
93
+ """
94
+ n, _, h, w = ref.size()
95
+
96
+ # normalize the input images
97
+ ref = [(ref - self.mean) / self.std]
98
+ supp = [(supp - self.mean) / self.std]
99
+
100
+ # generate downsampled frames
101
+ for level in range(5):
102
+ ref.append(
103
+ F.avg_pool2d(input=ref[-1],
104
+ kernel_size=2,
105
+ stride=2,
106
+ count_include_pad=False))
107
+ supp.append(
108
+ F.avg_pool2d(input=supp[-1],
109
+ kernel_size=2,
110
+ stride=2,
111
+ count_include_pad=False))
112
+ ref = ref[::-1]
113
+ supp = supp[::-1]
114
+
115
+ # flow computation
116
+ flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
117
+ for level in range(len(ref)):
118
+ if level == 0:
119
+ flow_up = flow
120
+ else:
121
+ flow_up = F.interpolate(input=flow,
122
+ scale_factor=2,
123
+ mode='bilinear',
124
+ align_corners=True) * 2.0
125
+
126
+ # add the residue to the upsampled flow
127
+ flow = flow_up + self.basic_module[level](torch.cat([
128
+ ref[level],
129
+ flow_warp(supp[level],
130
+ flow_up.permute(0, 2, 3, 1).contiguous(),
131
+ padding_mode='border'), flow_up
132
+ ], 1))
133
+
134
+ return flow
135
+
136
+ def forward(self, ref, supp):
137
+ """Forward function of SPyNet.
138
+ This function computes the optical flow from ref to supp.
139
+ Args:
140
+ ref (Tensor): Reference image with shape of (n, 3, h, w).
141
+ supp (Tensor): Supporting image with shape of (n, 3, h, w).
142
+ Returns:
143
+ Tensor: Estimated optical flow: (n, 2, h, w).
144
+ """
145
+
146
+ # upsize to a multiple of 32
147
+ h, w = ref.shape[2:4]
148
+ w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
149
+ h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
150
+ ref = F.interpolate(input=ref,
151
+ size=(h_up, w_up),
152
+ mode='bilinear',
153
+ align_corners=False)
154
+ supp = F.interpolate(input=supp,
155
+ size=(h_up, w_up),
156
+ mode='bilinear',
157
+ align_corners=False)
158
+
159
+ # compute flow, and resize back to the original resolution
160
+ flow = F.interpolate(input=self.compute_flow(ref, supp),
161
+ size=(h, w),
162
+ mode='bilinear',
163
+ align_corners=False)
164
+
165
+ # adjust the flow values
166
+ flow[:, 0, :, :] *= float(w) / float(w_up)
167
+ flow[:, 1, :, :] *= float(h) / float(h_up)
168
+
169
+ return flow
170
+
171
+
172
+ class SPyNetBasicModule(nn.Module):
173
+ """Basic Module for SPyNet.
174
+ Paper:
175
+ Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
176
+ """
177
+ def __init__(self):
178
+ super().__init__()
179
+
180
+ self.basic_module = nn.Sequential(
181
+ ConvModule(in_channels=8,
182
+ out_channels=32,
183
+ kernel_size=7,
184
+ stride=1,
185
+ padding=3,
186
+ norm_cfg=None,
187
+ act_cfg=dict(type='ReLU')),
188
+ ConvModule(in_channels=32,
189
+ out_channels=64,
190
+ kernel_size=7,
191
+ stride=1,
192
+ padding=3,
193
+ norm_cfg=None,
194
+ act_cfg=dict(type='ReLU')),
195
+ ConvModule(in_channels=64,
196
+ out_channels=32,
197
+ kernel_size=7,
198
+ stride=1,
199
+ padding=3,
200
+ norm_cfg=None,
201
+ act_cfg=dict(type='ReLU')),
202
+ ConvModule(in_channels=32,
203
+ out_channels=16,
204
+ kernel_size=7,
205
+ stride=1,
206
+ padding=3,
207
+ norm_cfg=None,
208
+ act_cfg=dict(type='ReLU')),
209
+ ConvModule(in_channels=16,
210
+ out_channels=2,
211
+ kernel_size=7,
212
+ stride=1,
213
+ padding=3,
214
+ norm_cfg=None,
215
+ act_cfg=None))
216
+
217
+ def forward(self, tensor_input):
218
+ """
219
+ Args:
220
+ tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
221
+ 8 channels contain:
222
+ [reference image (3), neighbor image (3), initial flow (2)].
223
+ Returns:
224
+ Tensor: Refined flow with shape (b, 2, h, w)
225
+ """
226
+ return self.basic_module(tensor_input)
227
+
228
+
229
+ # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
230
+ def make_colorwheel():
231
+ """
232
+ Generates a color wheel for optical flow visualization as presented in:
233
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
234
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
235
+
236
+ Code follows the original C++ source code of Daniel Scharstein.
237
+ Code follows the the Matlab source code of Deqing Sun.
238
+
239
+ Returns:
240
+ np.ndarray: Color wheel
241
+ """
242
+
243
+ RY = 15
244
+ YG = 6
245
+ GC = 4
246
+ CB = 11
247
+ BM = 13
248
+ MR = 6
249
+
250
+ ncols = RY + YG + GC + CB + BM + MR
251
+ colorwheel = np.zeros((ncols, 3))
252
+ col = 0
253
+
254
+ # RY
255
+ colorwheel[0:RY, 0] = 255
256
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
257
+ col = col + RY
258
+ # YG
259
+ colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
260
+ colorwheel[col:col + YG, 1] = 255
261
+ col = col + YG
262
+ # GC
263
+ colorwheel[col:col + GC, 1] = 255
264
+ colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
265
+ col = col + GC
266
+ # CB
267
+ colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
268
+ colorwheel[col:col + CB, 2] = 255
269
+ col = col + CB
270
+ # BM
271
+ colorwheel[col:col + BM, 2] = 255
272
+ colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
273
+ col = col + BM
274
+ # MR
275
+ colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
276
+ colorwheel[col:col + MR, 0] = 255
277
+ return colorwheel
278
+
279
+
280
+ def flow_uv_to_colors(u, v, convert_to_bgr=False):
281
+ """
282
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
283
+
284
+ According to the C++ source code of Daniel Scharstein
285
+ According to the Matlab source code of Deqing Sun
286
+
287
+ Args:
288
+ u (np.ndarray): Input horizontal flow of shape [H,W]
289
+ v (np.ndarray): Input vertical flow of shape [H,W]
290
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
291
+
292
+ Returns:
293
+ np.ndarray: Flow visualization image of shape [H,W,3]
294
+ """
295
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
296
+ colorwheel = make_colorwheel() # shape [55x3]
297
+ ncols = colorwheel.shape[0]
298
+ rad = np.sqrt(np.square(u) + np.square(v))
299
+ a = np.arctan2(-v, -u) / np.pi
300
+ fk = (a + 1) / 2 * (ncols - 1)
301
+ k0 = np.floor(fk).astype(np.int32)
302
+ k1 = k0 + 1
303
+ k1[k1 == ncols] = 0
304
+ f = fk - k0
305
+ for i in range(colorwheel.shape[1]):
306
+ tmp = colorwheel[:, i]
307
+ col0 = tmp[k0] / 255.0
308
+ col1 = tmp[k1] / 255.0
309
+ col = (1 - f) * col0 + f * col1
310
+ idx = (rad <= 1)
311
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
312
+ col[~idx] = col[~idx] * 0.75 # out of range
313
+ # Note the 2-i => BGR instead of RGB
314
+ ch_idx = 2 - i if convert_to_bgr else i
315
+ flow_image[:, :, ch_idx] = np.floor(255 * col)
316
+ return flow_image
317
+
318
+
319
+ def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
320
+ """
321
+ Expects a two dimensional flow image of shape.
322
+
323
+ Args:
324
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
325
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
326
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
327
+
328
+ Returns:
329
+ np.ndarray: Flow visualization image of shape [H,W,3]
330
+ """
331
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
332
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
333
+ if clip_flow is not None:
334
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
335
+ u = flow_uv[:, :, 0]
336
+ v = flow_uv[:, :, 1]
337
+ rad = np.sqrt(np.square(u) + np.square(v))
338
+ rad_max = np.max(rad)
339
+ epsilon = 1e-5
340
+ u = u / (rad_max + epsilon)
341
+ v = v / (rad_max + epsilon)
342
+ return flow_uv_to_colors(u, v, convert_to_bgr)
343
+
344
+
345
+ def flow_warp(x,
346
+ flow,
347
+ interpolation='bilinear',
348
+ padding_mode='zeros',
349
+ align_corners=True):
350
+ """Warp an image or a feature map with optical flow.
351
+ Args:
352
+ x (Tensor): Tensor with size (n, c, h, w).
353
+ flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
354
+ a two-channel, denoting the width and height relative offsets.
355
+ Note that the values are not normalized to [-1, 1].
356
+ interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
357
+ Default: 'bilinear'.
358
+ padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
359
+ Default: 'zeros'.
360
+ align_corners (bool): Whether align corners. Default: True.
361
+ Returns:
362
+ Tensor: Warped image or feature map.
363
+ """
364
+ if x.size()[-2:] != flow.size()[1:3]:
365
+ raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
366
+ f'flow ({flow.size()[1:3]}) are not the same.')
367
+ _, _, h, w = x.size()
368
+ # create mesh grid
369
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
370
+ grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2)
371
+ grid.requires_grad = False
372
+
373
+ grid_flow = grid + flow
374
+ # scale grid_flow to [-1,1]
375
+ grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
376
+ grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
377
+ grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
378
+ output = F.grid_sample(x,
379
+ grid_flow,
380
+ mode=interpolation,
381
+ padding_mode=padding_mode,
382
+ align_corners=align_corners)
383
+ return output
384
+
385
+
386
+ def initial_mask_flow(mask):
387
+ """
388
+ mask 1 indicates valid pixel 0 indicates unknown pixel
389
+ """
390
+ B, T, C, H, W = mask.shape
391
+
392
+ # calculate relative position
393
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
394
+
395
+ grid_y, grid_x = grid_y.type_as(mask), grid_x.type_as(mask)
396
+ abs_relative_pos_y = H - torch.abs(grid_y[None, :, :] - grid_y[:, None, :])
397
+ relative_pos_y = H - (grid_y[None, :, :] - grid_y[:, None, :])
398
+
399
+ abs_relative_pos_x = W - torch.abs(grid_x[:, None, :] - grid_x[:, :, None])
400
+ relative_pos_x = W - (grid_x[:, None, :] - grid_x[:, :, None])
401
+
402
+ # calculate the nearest indices
403
+ pos_up = mask.unsqueeze(3).repeat(
404
+ 1, 1, 1, H, 1, 1).flip(4) * abs_relative_pos_y[None, None, None] * (
405
+ relative_pos_y <= H)[None, None, None]
406
+ nearest_indice_up = pos_up.max(dim=4)[1]
407
+
408
+ pos_down = mask.unsqueeze(3).repeat(1, 1, 1, H, 1, 1) * abs_relative_pos_y[
409
+ None, None, None] * (relative_pos_y <= H)[None, None, None]
410
+ nearest_indice_down = (pos_down).max(dim=4)[1]
411
+
412
+ pos_left = mask.unsqueeze(4).repeat(
413
+ 1, 1, 1, 1, W, 1).flip(5) * abs_relative_pos_x[None, None, None] * (
414
+ relative_pos_x <= W)[None, None, None]
415
+ nearest_indice_left = (pos_left).max(dim=5)[1]
416
+
417
+ pos_right = mask.unsqueeze(4).repeat(
418
+ 1, 1, 1, 1, W, 1) * abs_relative_pos_x[None, None, None] * (
419
+ relative_pos_x <= W)[None, None, None]
420
+ nearest_indice_right = (pos_right).max(dim=5)[1]
421
+
422
+ # NOTE: IMPORTANT !!! depending on how to use this offset
423
+ initial_offset_up = -(nearest_indice_up - grid_y[None, None, None]).flip(3)
424
+ initial_offset_down = nearest_indice_down - grid_y[None, None, None]
425
+
426
+ initial_offset_left = -(nearest_indice_left -
427
+ grid_x[None, None, None]).flip(4)
428
+ initial_offset_right = nearest_indice_right - grid_x[None, None, None]
429
+
430
+ # nearest_indice_x = (mask.unsqueeze(1).repeat(1, img_width, 1) * relative_pos_x).max(dim=2)[1]
431
+ # initial_offset_x = nearest_indice_x - grid_x
432
+
433
+ # handle the boundary cases
434
+ final_offset_down = (initial_offset_down < 0) * initial_offset_up + (
435
+ initial_offset_down > 0) * initial_offset_down
436
+ final_offset_up = (initial_offset_up > 0) * initial_offset_down + (
437
+ initial_offset_up < 0) * initial_offset_up
438
+ final_offset_right = (initial_offset_right < 0) * initial_offset_left + (
439
+ initial_offset_right > 0) * initial_offset_right
440
+ final_offset_left = (initial_offset_left > 0) * initial_offset_right + (
441
+ initial_offset_left < 0) * initial_offset_left
442
+ zero_offset = torch.zeros_like(final_offset_down)
443
+ # out = torch.cat([final_offset_left, zero_offset, final_offset_right, zero_offset, zero_offset, final_offset_up, zero_offset, final_offset_down], dim=2)
444
+ out = torch.cat([
445
+ zero_offset, final_offset_left, zero_offset, final_offset_right,
446
+ final_offset_up, zero_offset, final_offset_down, zero_offset
447
+ ],
448
+ dim=2)
449
+
450
+ return out
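A hedged sketch tying the flow_comp.py utilities together: estimate flow with SPyNet, warp with flow_warp, and visualize with flow_to_image. use_pretrain=False skips the checkpoint download purely for illustration; meaningful flow requires the pretrained weights, and the import path is an assumption.

import torch
from model.modules.flow_comp import SPyNet, flow_warp, flow_to_image  # assumed import path

spynet = SPyNet(use_pretrain=False)  # random weights here; True loads the OpenMMLab checkpoint
ref = torch.rand(1, 3, 60, 108)
supp = torch.rand(1, 3, 60, 108)
with torch.no_grad():
    flow = spynet(ref, supp)                            # (1, 2, 60, 108): flow from ref to supp
    warped = flow_warp(supp, flow.permute(0, 2, 3, 1))  # flow_warp expects flow as (n, h, w, 2)
vis = flow_to_image(flow[0].permute(1, 2, 0).numpy())   # (60, 108, 3) uint8 visualization
print(warped.shape, vis.shape)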
inpainter/model/modules/spectral_norm.py ADDED
@@ -0,0 +1,288 @@
1
+ """
2
+ Spectral Normalization from https://arxiv.org/abs/1802.05957
3
+ """
4
+ import torch
5
+ from torch.nn.functional import normalize
6
+
7
+
8
+ class SpectralNorm(object):
9
+ # Invariant before and after each forward call:
10
+ # u = normalize(W @ v)
11
+ # NB: At initialization, this invariant is not enforced
12
+
13
+ _version = 1
14
+
15
+ # At version 1:
16
+ # made `W` not a buffer,
17
+ # added `v` as a buffer, and
18
+ # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
19
+
20
+ def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
21
+ self.name = name
22
+ self.dim = dim
23
+ if n_power_iterations <= 0:
24
+ raise ValueError(
25
+ 'Expected n_power_iterations to be positive, but '
26
+ 'got n_power_iterations={}'.format(n_power_iterations))
27
+ self.n_power_iterations = n_power_iterations
28
+ self.eps = eps
29
+
30
+ def reshape_weight_to_matrix(self, weight):
31
+ weight_mat = weight
32
+ if self.dim != 0:
33
+ # permute dim to front
34
+ weight_mat = weight_mat.permute(
35
+ self.dim,
36
+ *[d for d in range(weight_mat.dim()) if d != self.dim])
37
+ height = weight_mat.size(0)
38
+ return weight_mat.reshape(height, -1)
39
+
40
+ def compute_weight(self, module, do_power_iteration):
41
+ # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
42
+ # updated in power iteration **in-place**. This is very important
43
+ # because in `DataParallel` forward, the vectors (being buffers) are
44
+ # broadcast from the parallelized module to each module replica,
45
+ # which is a new module object created on the fly. And each replica
46
+ # runs its own spectral norm power iteration. So simply assigning
47
+ # the updated vectors to the module this function runs on will cause
48
+ # the update to be lost forever. And the next time the parallelized
49
+ # module is replicated, the same randomly initialized vectors are
50
+ # broadcast and used!
51
+ #
52
+ # Therefore, to make the change propagate back, we rely on two
53
+ # important behaviors (also enforced via tests):
54
+ # 1. `DataParallel` doesn't clone storage if the broadcast tensor
55
+ # is already on correct device; and it makes sure that the
56
+ # parallelized module is already on `device[0]`.
57
+ # 2. If the out tensor in `out=` kwarg has correct shape, it will
58
+ # just fill in the values.
59
+ # Therefore, since the same power iteration is performed on all
60
+ # devices, simply updating the tensors in-place will make sure that
61
+ # the module replica on `device[0]` will update the _u vector on the
62
+ # parallized module (by shared storage).
63
+ #
64
+ # However, after we update `u` and `v` in-place, we need to **clone**
65
+ # them before using them to normalize the weight. This is to support
66
+ # backproping through two forward passes, e.g., the common pattern in
67
+ # GAN training: loss = D(real) - D(fake). Otherwise, engine will
68
+ # complain that variables needed to do backward for the first forward
69
+ # (i.e., the `u` and `v` vectors) are changed in the second forward.
70
+ weight = getattr(module, self.name + '_orig')
71
+ u = getattr(module, self.name + '_u')
72
+ v = getattr(module, self.name + '_v')
73
+ weight_mat = self.reshape_weight_to_matrix(weight)
74
+
75
+ if do_power_iteration:
76
+ with torch.no_grad():
77
+ for _ in range(self.n_power_iterations):
78
+ # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
79
+ # are the first left and right singular vectors.
80
+ # This power iteration produces approximations of `u` and `v`.
81
+ v = normalize(torch.mv(weight_mat.t(), u),
82
+ dim=0,
83
+ eps=self.eps,
84
+ out=v)
85
+ u = normalize(torch.mv(weight_mat, v),
86
+ dim=0,
87
+ eps=self.eps,
88
+ out=u)
89
+ if self.n_power_iterations > 0:
90
+ # See above on why we need to clone
91
+ u = u.clone()
92
+ v = v.clone()
93
+
94
+ sigma = torch.dot(u, torch.mv(weight_mat, v))
95
+ weight = weight / sigma
96
+ return weight
97
+
98
+ def remove(self, module):
99
+ with torch.no_grad():
100
+ weight = self.compute_weight(module, do_power_iteration=False)
101
+ delattr(module, self.name)
102
+ delattr(module, self.name + '_u')
103
+ delattr(module, self.name + '_v')
104
+ delattr(module, self.name + '_orig')
105
+ module.register_parameter(self.name,
106
+ torch.nn.Parameter(weight.detach()))
107
+
108
+ def __call__(self, module, inputs):
109
+ setattr(
110
+ module, self.name,
111
+ self.compute_weight(module, do_power_iteration=module.training))
112
+
113
+ def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
114
+ # Tries to returns a vector `v` s.t. `u = normalize(W @ v)`
115
+ # (the invariant at top of this class) and `u @ W @ v = sigma`.
116
+ # This uses pinverse in case W^T W is not invertible.
117
+ v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
118
+ weight_mat.t(), u.unsqueeze(1)).squeeze(1)
119
+ return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
120
+
121
+ @staticmethod
122
+ def apply(module, name, n_power_iterations, dim, eps):
123
+ for k, hook in module._forward_pre_hooks.items():
124
+ if isinstance(hook, SpectralNorm) and hook.name == name:
125
+ raise RuntimeError(
126
+ "Cannot register two spectral_norm hooks on "
127
+ "the same parameter {}".format(name))
128
+
129
+ fn = SpectralNorm(name, n_power_iterations, dim, eps)
130
+ weight = module._parameters[name]
131
+
132
+ with torch.no_grad():
133
+ weight_mat = fn.reshape_weight_to_matrix(weight)
134
+
135
+ h, w = weight_mat.size()
136
+ # randomly initialize `u` and `v`
137
+ u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
138
+ v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
139
+
140
+ delattr(module, fn.name)
141
+ module.register_parameter(fn.name + "_orig", weight)
142
+ # We still need to assign weight back as fn.name because all sorts of
143
+ # things may assume that it exists, e.g., when initializing weights.
144
+ # However, we can't directly assign as it could be an nn.Parameter and
145
+ # gets added as a parameter. Instead, we register weight.data as a plain
146
+ # attribute.
147
+ setattr(module, fn.name, weight.data)
148
+ module.register_buffer(fn.name + "_u", u)
149
+ module.register_buffer(fn.name + "_v", v)
150
+
151
+ module.register_forward_pre_hook(fn)
152
+
153
+ module._register_state_dict_hook(SpectralNormStateDictHook(fn))
154
+ module._register_load_state_dict_pre_hook(
155
+ SpectralNormLoadStateDictPreHook(fn))
156
+ return fn
157
+
158
+
159
+ # This is a top level class because Py2 pickle doesn't like inner class nor an
160
+ # instancemethod.
161
+ class SpectralNormLoadStateDictPreHook(object):
162
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
163
+ def __init__(self, fn):
164
+ self.fn = fn
165
+
166
+ # For state_dict with version None, (assuming that it has gone through at
167
+ # least one training forward), we have
168
+ #
169
+ # u = normalize(W_orig @ v)
170
+ # W = W_orig / sigma, where sigma = u @ W_orig @ v
171
+ #
172
+ # To compute `v`, we solve `W_orig @ x = u`, and let
173
+ # v = x / (u @ W_orig @ x) * (W / W_orig).
174
+ def __call__(self, state_dict, prefix, local_metadata, strict,
175
+ missing_keys, unexpected_keys, error_msgs):
176
+ fn = self.fn
177
+ version = local_metadata.get('spectral_norm',
178
+ {}).get(fn.name + '.version', None)
179
+ if version is None or version < 1:
180
+ with torch.no_grad():
181
+ weight_orig = state_dict[prefix + fn.name + '_orig']
182
+ # weight = state_dict.pop(prefix + fn.name)
183
+ # sigma = (weight_orig / weight).mean()
184
+ weight_mat = fn.reshape_weight_to_matrix(weight_orig)
185
+ u = state_dict[prefix + fn.name + '_u']
186
+ # v = fn._solve_v_and_rescale(weight_mat, u, sigma)
187
+ # state_dict[prefix + fn.name + '_v'] = v
188
+
189
+
190
+ # This is a top level class because Py2 pickle doesn't like inner class nor an
191
+ # instancemethod.
192
+ class SpectralNormStateDictHook(object):
193
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
194
+ def __init__(self, fn):
195
+ self.fn = fn
196
+
197
+ def __call__(self, module, state_dict, prefix, local_metadata):
198
+ if 'spectral_norm' not in local_metadata:
199
+ local_metadata['spectral_norm'] = {}
200
+ key = self.fn.name + '.version'
201
+ if key in local_metadata['spectral_norm']:
202
+ raise RuntimeError(
203
+ "Unexpected key in metadata['spectral_norm']: {}".format(key))
204
+ local_metadata['spectral_norm'][key] = self.fn._version
205
+
206
+
207
+ def spectral_norm(module,
208
+ name='weight',
209
+ n_power_iterations=1,
210
+ eps=1e-12,
211
+ dim=None):
212
+ r"""Applies spectral normalization to a parameter in the given module.
213
+
214
+ .. math::
215
+ \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
216
+ \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
217
+
218
+ Spectral normalization stabilizes the training of discriminators (critics)
219
+ in Generative Adversarial Networks (GANs) by rescaling the weight tensor
220
+ with spectral norm :math:`\sigma` of the weight matrix calculated using
221
+ power iteration method. If the dimension of the weight tensor is greater
222
+ than 2, it is reshaped to 2D in power iteration method to get spectral
223
+ norm. This is implemented via a hook that calculates spectral norm and
224
+ rescales weight before every :meth:`~Module.forward` call.
225
+
226
+ See `Spectral Normalization for Generative Adversarial Networks`_ .
227
+
228
+ .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
229
+
230
+ Args:
231
+ module (nn.Module): containing module
232
+ name (str, optional): name of weight parameter
233
+ n_power_iterations (int, optional): number of power iterations to
234
+ calculate spectral norm
235
+ eps (float, optional): epsilon for numerical stability in
236
+ calculating norms
237
+ dim (int, optional): dimension corresponding to number of outputs,
238
+ the default is ``0``, except for modules that are instances of
239
+ ConvTranspose{1,2,3}d, when it is ``1``
240
+
241
+ Returns:
242
+ The original module with the spectral norm hook
243
+
244
+ Example::
245
+
246
+ >>> m = spectral_norm(nn.Linear(20, 40))
247
+ >>> m
248
+ Linear(in_features=20, out_features=40, bias=True)
249
+ >>> m.weight_u.size()
250
+ torch.Size([40])
251
+
252
+ """
253
+ if dim is None:
254
+ if isinstance(module,
255
+ (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
256
+ torch.nn.ConvTranspose3d)):
257
+ dim = 1
258
+ else:
259
+ dim = 0
260
+ SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
261
+ return module
262
+
263
+
264
+ def remove_spectral_norm(module, name='weight'):
265
+ r"""Removes the spectral normalization reparameterization from a module.
266
+
267
+ Args:
268
+ module (Module): containing module
269
+ name (str, optional): name of weight parameter
270
+
271
+ Example:
272
+ >>> m = spectral_norm(nn.Linear(40, 10))
273
+ >>> remove_spectral_norm(m)
274
+ """
275
+ for k, hook in module._forward_pre_hooks.items():
276
+ if isinstance(hook, SpectralNorm) and hook.name == name:
277
+ hook.remove(module)
278
+ del module._forward_pre_hooks[k]
279
+ return module
280
+
281
+ raise ValueError("spectral_norm of '{}' not found in {}".format(
282
+ name, module))
283
+
284
+
285
+ def use_spectral_norm(module, use_sn=False):
286
+ if use_sn:
287
+ return spectral_norm(module)
288
+ return module
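A brief, hedged sketch of the spectral_norm helpers defined above, mirroring how the Discriminator wraps its Conv3d layers; the module choice and tensor sizes are illustrative only.

import torch
import torch.nn as nn
from model.modules.spectral_norm import spectral_norm, remove_spectral_norm  # assumed import path

conv = spectral_norm(nn.Conv3d(3, 32, kernel_size=(3, 5, 5), stride=(1, 2, 2),
                               padding=(1, 2, 2), bias=False))
x = torch.randn(1, 3, 6, 64, 64)   # (B, C, T, H, W)
y = conv(x)                        # weight is rescaled by its spectral norm on every forward
print(hasattr(conv, 'weight_u'))   # True: power-iteration buffer registered by the hook
conv = remove_spectral_norm(conv)  # strips the reparameterization again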
inpainter/model/modules/tfocal_transformer.py ADDED
@@ -0,0 +1,536 @@
1
+ """
2
+ This code is based on:
3
+ [1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
4
+ https://github.com/ruiliu-ai/FuseFormer
5
+ [2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
6
+ https://github.com/yitu-opensource/T2T-ViT
7
+ [3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
8
+ https://github.com/microsoft/Focal-Transformer
9
+ """
10
+
11
+ import math
12
+ from functools import reduce
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class SoftSplit(nn.Module):
20
+ def __init__(self, channel, hidden, kernel_size, stride, padding,
21
+ t2t_param):
22
+ super(SoftSplit, self).__init__()
23
+ self.kernel_size = kernel_size
24
+ self.t2t = nn.Unfold(kernel_size=kernel_size,
25
+ stride=stride,
26
+ padding=padding)
27
+ c_in = reduce((lambda x, y: x * y), kernel_size) * channel
28
+ self.embedding = nn.Linear(c_in, hidden)
29
+
30
+ self.f_h = int(
31
+ (t2t_param['output_size'][0] + 2 * t2t_param['padding'][0] -
32
+ (t2t_param['kernel_size'][0] - 1) - 1) / t2t_param['stride'][0] +
33
+ 1)
34
+ self.f_w = int(
35
+ (t2t_param['output_size'][1] + 2 * t2t_param['padding'][1] -
36
+ (t2t_param['kernel_size'][1] - 1) - 1) / t2t_param['stride'][1] +
37
+ 1)
38
+
39
+ def forward(self, x, b):
40
+ feat = self.t2t(x)
41
+ feat = feat.permute(0, 2, 1)
42
+ # feat shape [b*t, num_vec, ks*ks*c]
43
+ feat = self.embedding(feat)
44
+ # feat shape after embedding [b, t*num_vec, hidden]
45
+ feat = feat.view(b, -1, self.f_h, self.f_w, feat.size(2))
46
+ return feat
47
+
48
+
49
+ class SoftComp(nn.Module):
50
+ def __init__(self, channel, hidden, output_size, kernel_size, stride,
51
+ padding):
52
+ super(SoftComp, self).__init__()
53
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
54
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
55
+ self.embedding = nn.Linear(hidden, c_out)
56
+ self.t2t = torch.nn.Fold(output_size=output_size,
57
+ kernel_size=kernel_size,
58
+ stride=stride,
59
+ padding=padding)
60
+ h, w = output_size
61
+ self.bias = nn.Parameter(torch.zeros((channel, h, w),
62
+ dtype=torch.float32),
63
+ requires_grad=True)
64
+
65
+ def forward(self, x, t):
66
+ b_, _, _, _, c_ = x.shape
67
+ x = x.view(b_, -1, c_)
68
+ feat = self.embedding(x)
69
+ b, _, c = feat.size()
70
+ feat = feat.view(b * t, -1, c).permute(0, 2, 1)
71
+ feat = self.t2t(feat) + self.bias[None]
72
+ return feat
73
+
74
+
75
+ class FusionFeedForward(nn.Module):
76
+ def __init__(self, d_model, n_vecs=None, t2t_params=None):
77
+ super(FusionFeedForward, self).__init__()
78
+ # We set d_ff as a default to 1960
79
+ hd = 1960
80
+ self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
81
+ self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
82
+ assert t2t_params is not None and n_vecs is not None
83
+ tp = t2t_params.copy()
84
+ self.fold = nn.Fold(**tp)
85
+ del tp['output_size']
86
+ self.unfold = nn.Unfold(**tp)
87
+ self.n_vecs = n_vecs
88
+
89
+ def forward(self, x):
90
+ x = self.conv1(x)
91
+ b, n, c = x.size()
92
+ normalizer = x.new_ones(b, n, 49).view(-1, self.n_vecs,
93
+ 49).permute(0, 2, 1)
94
+ x = self.unfold(
95
+ self.fold(x.view(-1, self.n_vecs, c).permute(0, 2, 1)) /
96
+ self.fold(normalizer)).permute(0, 2, 1).contiguous().view(b, n, c)
97
+ x = self.conv2(x)
98
+ return x
99
+
100
+
101
+ def window_partition(x, window_size):
102
+ """
103
+ Args:
104
+ x: shape is (B, T, H, W, C)
105
+ window_size (tuple[int]): window size
106
+ Returns:
107
+ windows: (B*num_windows, T*window_size*window_size, C)
108
+ """
109
+ B, T, H, W, C = x.shape
110
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
111
+ window_size[1], C)
112
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
113
+ -1, T * window_size[0] * window_size[1], C)
114
+ return windows
115
+
116
+
117
+ def window_partition_noreshape(x, window_size):
118
+ """
119
+ Args:
120
+ x: shape is (B, T, H, W, C)
121
+ window_size (tuple[int]): window size
122
+ Returns:
123
+ windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
124
+ """
125
+ B, T, H, W, C = x.shape
126
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
127
+ window_size[1], C)
128
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
129
+ return windows
130
+
131
+
132
+ def window_reverse(windows, window_size, T, H, W):
133
+ """
134
+ Args:
135
+ windows: shape is (num_windows*B, T, window_size, window_size, C)
136
+ window_size (tuple[int]): Window size
137
+ T (int): Temporal length of video
138
+ H (int): Height of image
139
+ W (int): Width of image
140
+ Returns:
141
+ x: (B, T, H, W, C)
142
+ """
143
+ B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
144
+ x = windows.view(B, H // window_size[0], W // window_size[1], T,
145
+ window_size[0], window_size[1], -1)
146
+ x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
147
+ return x
148
+
149
+
150
+ class WindowAttention(nn.Module):
151
+ """Temporal focal window attention
152
+ """
153
+ def __init__(self, dim, expand_size, window_size, focal_window,
154
+ focal_level, num_heads, qkv_bias, pool_method):
155
+
156
+ super().__init__()
157
+ self.dim = dim
158
+ self.expand_size = expand_size
159
+ self.window_size = window_size # Wh, Ww
160
+ self.pool_method = pool_method
161
+ self.num_heads = num_heads
162
+ head_dim = dim // num_heads
163
+ self.scale = head_dim**-0.5
164
+ self.focal_level = focal_level
165
+ self.focal_window = focal_window
166
+
167
+ if any(i > 0 for i in self.expand_size) and focal_level > 0:
168
+ # get mask for rolled k and rolled v
169
+ mask_tl = torch.ones(self.window_size[0], self.window_size[1])
170
+ mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
171
+ mask_tr = torch.ones(self.window_size[0], self.window_size[1])
172
+ mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
173
+ mask_bl = torch.ones(self.window_size[0], self.window_size[1])
174
+ mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
175
+ mask_br = torch.ones(self.window_size[0], self.window_size[1])
176
+ mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
177
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
178
+ 0).flatten(0)
179
+ self.register_buffer("valid_ind_rolled",
180
+ mask_rolled.nonzero(as_tuple=False).view(-1))
181
+
182
+ if pool_method != "none" and focal_level > 1:
183
+ self.unfolds = nn.ModuleList()
184
+
185
+ # build relative position bias between local patch and pooled windows
186
+ for k in range(focal_level - 1):
187
+ stride = 2**k
188
+ kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
189
+ for i in self.focal_window)
190
+ # define unfolding operations
191
+ self.unfolds += [
192
+ nn.Unfold(kernel_size=kernel_size,
193
+ stride=stride,
194
+ padding=tuple(i // 2 for i in kernel_size))
195
+ ]
196
+
197
+ # define unfolding index for focal_level > 0
198
+ if k > 0:
199
+ mask = torch.zeros(kernel_size)
200
+ mask[(2**k) - 1:, (2**k) - 1:] = 1
201
+ self.register_buffer(
202
+ "valid_ind_unfold_{}".format(k),
203
+ mask.flatten(0).nonzero(as_tuple=False).view(-1))
204
+
205
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
206
+ self.proj = nn.Linear(dim, dim)
207
+
208
+ self.softmax = nn.Softmax(dim=-1)
209
+
210
+ def forward(self, x_all, mask_all=None):
211
+ """
212
+ Args:
213
+ x: input features with shape of (B, T, Wh, Ww, C)
214
+ mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
215
+
216
+ output: (nW*B, Wh*Ww, C)
217
+ """
218
+ x = x_all[0]
219
+
220
+ B, T, nH, nW, C = x.shape
221
+ qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
222
+ C).permute(4, 0, 1, 2, 3, 5).contiguous()
223
+ q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
224
+
225
+ # partition q map
226
+ (q_windows, k_windows, v_windows) = map(
227
+ lambda t: window_partition(t, self.window_size).view(
228
+ -1, T, self.window_size[0] * self.window_size[1], self.
229
+ num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
230
+ contiguous().view(-1, self.num_heads, T * self.window_size[
231
+ 0] * self.window_size[1], C // self.num_heads), (q, k, v))
232
+ # q(k/v)_windows shape : [16, 4, 225, 128]
233
+
234
+ if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
235
+ (k_tl, v_tl) = map(
236
+ lambda t: torch.roll(t,
237
+ shifts=(-self.expand_size[0], -self.
238
+ expand_size[1]),
239
+ dims=(2, 3)), (k, v))
240
+ (k_tr, v_tr) = map(
241
+ lambda t: torch.roll(t,
242
+ shifts=(-self.expand_size[0], self.
243
+ expand_size[1]),
244
+ dims=(2, 3)), (k, v))
245
+ (k_bl, v_bl) = map(
246
+ lambda t: torch.roll(t,
247
+ shifts=(self.expand_size[0], -self.
248
+ expand_size[1]),
249
+ dims=(2, 3)), (k, v))
250
+ (k_br, v_br) = map(
251
+ lambda t: torch.roll(t,
252
+ shifts=(self.expand_size[0], self.
253
+ expand_size[1]),
254
+ dims=(2, 3)), (k, v))
255
+
256
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
257
+ lambda t: window_partition(t, self.window_size).view(
258
+ -1, T, self.window_size[0] * self.window_size[1], self.
259
+ num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
260
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
261
+ lambda t: window_partition(t, self.window_size).view(
262
+ -1, T, self.window_size[0] * self.window_size[1], self.
263
+ num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
264
+ k_rolled = torch.cat(
265
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
266
+ 2).permute(0, 3, 1, 2, 4).contiguous()
267
+ v_rolled = torch.cat(
268
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
269
+ 2).permute(0, 3, 1, 2, 4).contiguous()
270
+
271
+ # mask out tokens in current window
272
+ k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
273
+ v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
274
+ temp_N = k_rolled.shape[3]
275
+ k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
276
+ C // self.num_heads)
277
+ v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
278
+ C // self.num_heads)
279
+ k_rolled = torch.cat((k_windows, k_rolled), 2)
280
+ v_rolled = torch.cat((v_windows, v_rolled), 2)
281
+ else:
282
+ k_rolled = k_windows
283
+ v_rolled = v_windows
284
+
285
+ # q(k/v)_windows shape : [16, 4, 225, 128]
286
+ # k_rolled.shape : [16, 4, 5, 165, 128]
287
+ # ideal expanded window size 153 ((5+2*2)*(9+2*4))
288
+ # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
289
+
290
+ if self.pool_method != "none" and self.focal_level > 1:
291
+ k_pooled = []
292
+ v_pooled = []
293
+ for k in range(self.focal_level - 1):
294
+ stride = 2**k
295
+ x_window_pooled = x_all[k + 1].permute(
296
+ 0, 3, 1, 2, 4).contiguous() # B, T, nWh, nWw, C
297
+
298
+ nWh, nWw = x_window_pooled.shape[2:4]
299
+
300
+ # generate mask for pooled windows
301
+ mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
302
+ # unfold mask: [nWh*nWw//s//s, k*k, 1]
303
+ unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
304
+ 1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
305
+ view(nWh*nWw // stride // stride, -1, 1)
306
+
307
+ if k > 0:
308
+ valid_ind_unfold_k = getattr(
309
+ self, "valid_ind_unfold_{}".format(k))
310
+ unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
311
+
312
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
313
+ x_window_masks = x_window_masks.masked_fill(
314
+ x_window_masks == 0,
315
+ float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
316
+ mask_all[k + 1] = x_window_masks
317
+
318
+ # generate k and v for pooled windows
319
+ qkv_pooled = self.qkv(x_window_pooled).reshape(
320
+ B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
321
+ 3).view(3, -1, C, nWh,
322
+ nWw).contiguous()
323
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[
324
+ 2] # B*T, C, nWh, nWw
325
+ # k_pooled_k shape: [5, 512, 4, 4]
326
+ # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
327
+
328
+ (k_pooled_k, v_pooled_k) = map(
329
+ lambda t: self.unfolds[k](t).view(
330
+ B, T, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 5, 1, 3, 4, 2).contiguous().\
331
+ view(-1, T, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).contiguous(),
332
+ (k_pooled_k, v_pooled_k) # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
333
+ )
334
+ # k_pooled_k shape : [16, 4, 5, 45, 128]
335
+
336
+ # select valid unfolding index
337
+ if k > 0:
338
+ (k_pooled_k, v_pooled_k) = map(
339
+ lambda t: t[:, :, :, valid_ind_unfold_k],
340
+ (k_pooled_k, v_pooled_k))
341
+
342
+ k_pooled_k = k_pooled_k.view(
343
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
344
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
345
+ v_pooled_k = v_pooled_k.view(
346
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
347
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
348
+
349
+ k_pooled += [k_pooled_k]
350
+ v_pooled += [v_pooled_k]
351
+
352
+ # k_all (v_all) shape : [16, 4, 5 * 210, 128]
353
+ k_all = torch.cat([k_rolled] + k_pooled, 2)
354
+ v_all = torch.cat([v_rolled] + v_pooled, 2)
355
+ else:
356
+ k_all = k_rolled
357
+ v_all = v_rolled
358
+
359
+ N = k_all.shape[-2]
360
+ q_windows = q_windows * self.scale
361
+ attn = (
362
+ q_windows @ k_all.transpose(-2, -1)
363
+ ) # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
364
+ # T * 45
365
+ window_area = T * self.window_size[0] * self.window_size[1]
366
+ # T * 165
367
+ window_area_rolled = k_rolled.shape[2]
368
+
369
+ if self.pool_method != "none" and self.focal_level > 1:
370
+ offset = window_area_rolled
371
+ for k in range(self.focal_level - 1):
372
+ # add attentional mask
373
+ # mask_all[1] shape [1, 16, T * 45]
374
+
375
+ bias = tuple((i + 2**k - 1) for i in self.focal_window)
376
+
377
+ if mask_all[k + 1] is not None:
378
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
379
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
380
+ mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
381
+
382
+ offset += T * bias[0] * bias[1]
383
+
384
+ if mask_all[0] is not None:
385
+ nW = mask_all[0].shape[0]
386
+ attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
387
+ window_area, N)
388
+ attn[:, :, :, :, :
389
+ window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
390
+ None, :, None, :, :]
391
+ attn = attn.view(-1, self.num_heads, window_area, N)
392
+ attn = self.softmax(attn)
393
+ else:
394
+ attn = self.softmax(attn)
395
+
396
+ x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
397
+ C)
398
+ x = self.proj(x)
399
+ return x
400
+
401
+
402
+ class TemporalFocalTransformerBlock(nn.Module):
403
+ r""" Temporal Focal Transformer Block.
404
+ Args:
405
+ dim (int): Number of input channels.
406
+ num_heads (int): Number of attention heads.
407
+ window_size (tuple[int]): Window size.
408
+ shift_size (int): Shift size for SW-MSA.
409
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
410
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
411
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
412
+ focal_level (int): Number of focal levels.
413
+ focal_window (tuple[int]): Size of the focal window.
414
+ n_vecs (int): Required for F3N.
415
+ t2t_params (int): T2T parameters for F3N.
416
+ """
417
+ def __init__(self,
418
+ dim,
419
+ num_heads,
420
+ window_size=(5, 9),
421
+ mlp_ratio=4.,
422
+ qkv_bias=True,
423
+ pool_method="fc",
424
+ focal_level=2,
425
+ focal_window=(5, 9),
426
+ norm_layer=nn.LayerNorm,
427
+ n_vecs=None,
428
+ t2t_params=None):
429
+ super().__init__()
430
+ self.dim = dim
431
+ self.num_heads = num_heads
432
+ self.window_size = window_size
433
+ self.expand_size = tuple(i // 2 for i in window_size) # TODO
434
+ self.mlp_ratio = mlp_ratio
435
+ self.pool_method = pool_method
436
+ self.focal_level = focal_level
437
+ self.focal_window = focal_window
438
+
439
+ self.window_size_glo = self.window_size
440
+
441
+ self.pool_layers = nn.ModuleList()
442
+ if self.pool_method != "none":
443
+ for k in range(self.focal_level - 1):
444
+ window_size_glo = tuple(
445
+ math.floor(i / (2**k)) for i in self.window_size_glo)
446
+ self.pool_layers.append(
447
+ nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
448
+ self.pool_layers[-1].weight.data.fill_(
449
+ 1. / (window_size_glo[0] * window_size_glo[1]))
450
+ self.pool_layers[-1].bias.data.fill_(0)
451
+
452
+ self.norm1 = norm_layer(dim)
453
+
454
+ self.attn = WindowAttention(dim,
455
+ expand_size=self.expand_size,
456
+ window_size=self.window_size,
457
+ focal_window=focal_window,
458
+ focal_level=focal_level,
459
+ num_heads=num_heads,
460
+ qkv_bias=qkv_bias,
461
+ pool_method=pool_method)
462
+
463
+ self.norm2 = norm_layer(dim)
464
+ self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
465
+
466
+ def forward(self, x):
467
+ B, T, H, W, C = x.shape
468
+
469
+ shortcut = x
470
+ x = self.norm1(x)
471
+
472
+ shifted_x = x
473
+
474
+ x_windows_all = [shifted_x]
475
+ x_window_masks_all = [None]
476
+
477
+ # partition windows tuple(i // 2 for i in window_size)
478
+ if self.focal_level > 1 and self.pool_method != "none":
479
+ # if we add coarser granularity and the pool method is not none
480
+ for k in range(self.focal_level - 1):
481
+ window_size_glo = tuple(
482
+ math.floor(i / (2**k)) for i in self.window_size_glo)
483
+ pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
484
+ pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
485
+ H_pool = pooled_h * window_size_glo[0]
486
+ W_pool = pooled_w * window_size_glo[1]
487
+
488
+ x_level_k = shifted_x
489
+ # trim or pad shifted_x depending on the required size
490
+ if H > H_pool:
491
+ trim_t = (H - H_pool) // 2
492
+ trim_b = H - H_pool - trim_t
493
+ x_level_k = x_level_k[:, :, trim_t:-trim_b]
494
+ elif H < H_pool:
495
+ pad_t = (H_pool - H) // 2
496
+ pad_b = H_pool - H - pad_t
497
+ x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
498
+
499
+ if W > W_pool:
500
+ trim_l = (W - W_pool) // 2
501
+ trim_r = W - W_pool - trim_l
502
+ x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
503
+ elif W < W_pool:
504
+ pad_l = (W_pool - W) // 2
505
+ pad_r = W_pool - W - pad_l
506
+ x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
507
+
508
+ x_windows_noreshape = window_partition_noreshape(
509
+ x_level_k.contiguous(), window_size_glo
510
+ ) # B, nw, nw, T, window_size, window_size, C
511
+ nWh, nWw = x_windows_noreshape.shape[1:3]
512
+ x_windows_noreshape = x_windows_noreshape.view(
513
+ B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
514
+ C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
515
+ x_windows_pooled = self.pool_layers[k](
516
+ x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
517
+
518
+ x_windows_all += [x_windows_pooled]
519
+ x_window_masks_all += [None]
520
+
521
+ attn_windows = self.attn(
522
+ x_windows_all,
523
+ mask_all=x_window_masks_all) # nW*B, T*window_size*window_size, C
524
+
525
+ # merge windows
526
+ attn_windows = attn_windows.view(-1, T, self.window_size[0],
527
+ self.window_size[1], C)
528
+ shifted_x = window_reverse(attn_windows, self.window_size, T, H,
529
+ W) # B T H' W' C
530
+
531
+ # FFN
532
+ x = shortcut + shifted_x
533
+ y = self.norm2(x)
534
+ x = x + self.mlp(y.view(B, T * H * W, C)).view(B, T, H, W, C)
535
+
536
+ return x
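The window_partition / window_reverse pair used above is lossless whenever H and W are multiples of the window size. A minimal shape check, assuming a 20x36 token grid and the (5, 9) window used throughout this module (all tensor sizes here are illustrative only):

import torch

B, T, H, W, C = 1, 5, 20, 36, 512
window_size = (5, 9)
x = torch.randn(B, T, H, W, C)

windows = window_partition(x, window_size)   # (B*nW, T*5*9, C) = (16, 225, 512)
restored = window_reverse(
    windows.view(-1, T, window_size[0], window_size[1], C),  # (nW*B, T, 5, 9, C), as in the block
    window_size, T, H, W)
assert restored.shape == (B, T, H, W, C)
assert torch.equal(restored, x)              # partition followed by reverse is the identity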
inpainter/model/modules/tfocal_transformer_hq.py ADDED
@@ -0,0 +1,565 @@
1
+ """
2
+ This code is based on:
3
+ [1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
4
+ https://github.com/ruiliu-ai/FuseFormer
5
+ [2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
6
+ https://github.com/yitu-opensource/T2T-ViT
7
+ [3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
8
+ https://github.com/microsoft/Focal-Transformer
9
+ """
10
+
11
+ import math
12
+ from functools import reduce
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class SoftSplit(nn.Module):
20
+ def __init__(self, channel, hidden, kernel_size, stride, padding,
21
+ t2t_param):
22
+ super(SoftSplit, self).__init__()
23
+ self.kernel_size = kernel_size
24
+ self.t2t = nn.Unfold(kernel_size=kernel_size,
25
+ stride=stride,
26
+ padding=padding)
27
+ c_in = reduce((lambda x, y: x * y), kernel_size) * channel
28
+ self.embedding = nn.Linear(c_in, hidden)
29
+
30
+ self.t2t_param = t2t_param
31
+
32
+ def forward(self, x, b, output_size):
33
+ f_h = int((output_size[0] + 2 * self.t2t_param['padding'][0] -
34
+ (self.t2t_param['kernel_size'][0] - 1) - 1) /
35
+ self.t2t_param['stride'][0] + 1)
36
+ f_w = int((output_size[1] + 2 * self.t2t_param['padding'][1] -
37
+ (self.t2t_param['kernel_size'][1] - 1) - 1) /
38
+ self.t2t_param['stride'][1] + 1)
39
+
40
+ feat = self.t2t(x)
41
+ feat = feat.permute(0, 2, 1)
42
+ # feat shape [b*t, num_vec, ks*ks*c]
43
+ feat = self.embedding(feat)
44
+ # feat shape after embedding [b, t*num_vec, hidden]
45
+ feat = feat.view(b, -1, f_h, f_w, feat.size(2))
46
+ return feat
47
+
48
+
49
+ class SoftComp(nn.Module):
50
+ def __init__(self, channel, hidden, kernel_size, stride, padding):
51
+ super(SoftComp, self).__init__()
52
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
53
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
54
+ self.embedding = nn.Linear(hidden, c_out)
55
+ self.kernel_size = kernel_size
56
+ self.stride = stride
57
+ self.padding = padding
58
+ self.bias_conv = nn.Conv2d(channel,
59
+ channel,
60
+ kernel_size=3,
61
+ stride=1,
62
+ padding=1)
63
+ # TODO upsample conv
64
+ # self.bias_conv = nn.Conv2d()
65
+ # self.bias = nn.Parameter(torch.zeros((channel, h, w), dtype=torch.float32), requires_grad=True)
66
+
67
+ def forward(self, x, t, output_size):
68
+ b_, _, _, _, c_ = x.shape
69
+ x = x.view(b_, -1, c_)
70
+ feat = self.embedding(x)
71
+ b, _, c = feat.size()
72
+ feat = feat.view(b * t, -1, c).permute(0, 2, 1)
73
+ feat = F.fold(feat,
74
+ output_size=output_size,
75
+ kernel_size=self.kernel_size,
76
+ stride=self.stride,
77
+ padding=self.padding)
78
+ feat = self.bias_conv(feat)
79
+ return feat
80
+
81
+
82
+ class FusionFeedForward(nn.Module):
83
+ def __init__(self, d_model, n_vecs=None, t2t_params=None):
84
+ super(FusionFeedForward, self).__init__()
85
+ # We set d_ff to 1960 by default
86
+ hd = 1960
87
+ self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
88
+ self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
89
+ assert t2t_params is not None and n_vecs is not None
90
+ self.t2t_params = t2t_params
91
+
92
+ def forward(self, x, output_size):
93
+ n_vecs = 1
94
+ for i, d in enumerate(self.t2t_params['kernel_size']):
95
+ n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
96
+ (d - 1) - 1) / self.t2t_params['stride'][i] + 1)
97
+
98
+ x = self.conv1(x)
99
+ b, n, c = x.size()
100
+ normalizer = x.new_ones(b, n, 49).view(-1, n_vecs, 49).permute(0, 2, 1)
101
+ normalizer = F.fold(normalizer,
102
+ output_size=output_size,
103
+ kernel_size=self.t2t_params['kernel_size'],
104
+ padding=self.t2t_params['padding'],
105
+ stride=self.t2t_params['stride'])
106
+
107
+ x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
108
+ output_size=output_size,
109
+ kernel_size=self.t2t_params['kernel_size'],
110
+ padding=self.t2t_params['padding'],
111
+ stride=self.t2t_params['stride'])
112
+
113
+ x = F.unfold(x / normalizer,
114
+ kernel_size=self.t2t_params['kernel_size'],
115
+ padding=self.t2t_params['padding'],
116
+ stride=self.t2t_params['stride']).permute(
117
+ 0, 2, 1).contiguous().view(b, n, c)
118
+ x = self.conv2(x)
119
+ return x
120
+
121
+
122
+ def window_partition(x, window_size):
123
+ """
124
+ Args:
125
+ x: shape is (B, T, H, W, C)
126
+ window_size (tuple[int]): window size
127
+ Returns:
128
+ windows: (B*num_windows, T*window_size*window_size, C)
129
+ """
130
+ B, T, H, W, C = x.shape
131
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
132
+ window_size[1], C)
133
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
134
+ -1, T * window_size[0] * window_size[1], C)
135
+ return windows
136
+
137
+
138
+ def window_partition_noreshape(x, window_size):
139
+ """
140
+ Args:
141
+ x: shape is (B, T, H, W, C)
142
+ window_size (tuple[int]): window size
143
+ Returns:
144
+ windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
145
+ """
146
+ B, T, H, W, C = x.shape
147
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
148
+ window_size[1], C)
149
+ windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
150
+ return windows
151
+
152
+
153
+ def window_reverse(windows, window_size, T, H, W):
154
+ """
155
+ Args:
156
+ windows: shape is (num_windows*B, T, window_size, window_size, C)
157
+ window_size (tuple[int]): Window size
158
+ T (int): Temporal length of video
159
+ H (int): Height of image
160
+ W (int): Width of image
161
+ Returns:
162
+ x: (B, T, H, W, C)
163
+ """
164
+ B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
165
+ x = windows.view(B, H // window_size[0], W // window_size[1], T,
166
+ window_size[0], window_size[1], -1)
167
+ x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
168
+ return x
169
+
170
+
171
+ class WindowAttention(nn.Module):
172
+ """Temporal focal window attention
173
+ """
174
+ def __init__(self, dim, expand_size, window_size, focal_window,
175
+ focal_level, num_heads, qkv_bias, pool_method):
176
+
177
+ super().__init__()
178
+ self.dim = dim
179
+ self.expand_size = expand_size
180
+ self.window_size = window_size # Wh, Ww
181
+ self.pool_method = pool_method
182
+ self.num_heads = num_heads
183
+ head_dim = dim // num_heads
184
+ self.scale = head_dim**-0.5
185
+ self.focal_level = focal_level
186
+ self.focal_window = focal_window
187
+
188
+ if any(i > 0 for i in self.expand_size) and focal_level > 0:
189
+ # get mask for rolled k and rolled v
190
+ mask_tl = torch.ones(self.window_size[0], self.window_size[1])
191
+ mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
192
+ mask_tr = torch.ones(self.window_size[0], self.window_size[1])
193
+ mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
194
+ mask_bl = torch.ones(self.window_size[0], self.window_size[1])
195
+ mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
196
+ mask_br = torch.ones(self.window_size[0], self.window_size[1])
197
+ mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
198
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
199
+ 0).flatten(0)
200
+ self.register_buffer("valid_ind_rolled",
201
+ mask_rolled.nonzero(as_tuple=False).view(-1))
202
+
203
+ if pool_method != "none" and focal_level > 1:
204
+ self.unfolds = nn.ModuleList()
205
+
206
+ # build relative position bias between local patch and pooled windows
207
+ for k in range(focal_level - 1):
208
+ stride = 2**k
209
+ kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
210
+ for i in self.focal_window)
211
+ # define unfolding operations
212
+ self.unfolds += [
213
+ nn.Unfold(kernel_size=kernel_size,
214
+ stride=stride,
215
+ padding=tuple(i // 2 for i in kernel_size))
216
+ ]
217
+
218
+ # define unfolding index for focal_level > 0
219
+ if k > 0:
220
+ mask = torch.zeros(kernel_size)
221
+ mask[(2**k) - 1:, (2**k) - 1:] = 1
222
+ self.register_buffer(
223
+ "valid_ind_unfold_{}".format(k),
224
+ mask.flatten(0).nonzero(as_tuple=False).view(-1))
225
+
226
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
227
+ self.proj = nn.Linear(dim, dim)
228
+
229
+ self.softmax = nn.Softmax(dim=-1)
230
+
231
+ def forward(self, x_all, mask_all=None):
232
+ """
233
+ Args:
234
+ x: input features with shape of (B, T, Wh, Ww, C)
235
+ mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
236
+
237
+ output: (nW*B, T*Wh*Ww, C)
238
+ """
239
+ x = x_all[0]
240
+
241
+ B, T, nH, nW, C = x.shape
242
+ qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
243
+ C).permute(4, 0, 1, 2, 3, 5).contiguous()
244
+ q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
245
+
246
+ # partition q map
247
+ (q_windows, k_windows, v_windows) = map(
248
+ lambda t: window_partition(t, self.window_size).view(
249
+ -1, T, self.window_size[0] * self.window_size[1], self.
250
+ num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
251
+ contiguous().view(-1, self.num_heads, T * self.window_size[
252
+ 0] * self.window_size[1], C // self.num_heads), (q, k, v))
253
+ # q(k/v)_windows shape : [16, 4, 225, 128]
254
+
255
+ if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
256
+ (k_tl, v_tl) = map(
257
+ lambda t: torch.roll(t,
258
+ shifts=(-self.expand_size[0], -self.
259
+ expand_size[1]),
260
+ dims=(2, 3)), (k, v))
261
+ (k_tr, v_tr) = map(
262
+ lambda t: torch.roll(t,
263
+ shifts=(-self.expand_size[0], self.
264
+ expand_size[1]),
265
+ dims=(2, 3)), (k, v))
266
+ (k_bl, v_bl) = map(
267
+ lambda t: torch.roll(t,
268
+ shifts=(self.expand_size[0], -self.
269
+ expand_size[1]),
270
+ dims=(2, 3)), (k, v))
271
+ (k_br, v_br) = map(
272
+ lambda t: torch.roll(t,
273
+ shifts=(self.expand_size[0], self.
274
+ expand_size[1]),
275
+ dims=(2, 3)), (k, v))
276
+
277
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
278
+ lambda t: window_partition(t, self.window_size).view(
279
+ -1, T, self.window_size[0] * self.window_size[1], self.
280
+ num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
281
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
282
+ lambda t: window_partition(t, self.window_size).view(
283
+ -1, T, self.window_size[0] * self.window_size[1], self.
284
+ num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
285
+ k_rolled = torch.cat(
286
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
287
+ 2).permute(0, 3, 1, 2, 4).contiguous()
288
+ v_rolled = torch.cat(
289
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
290
+ 2).permute(0, 3, 1, 2, 4).contiguous()
291
+
292
+ # mask out tokens in current window
293
+ k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
294
+ v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
295
+ temp_N = k_rolled.shape[3]
296
+ k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
297
+ C // self.num_heads)
298
+ v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
299
+ C // self.num_heads)
300
+ k_rolled = torch.cat((k_windows, k_rolled), 2)
301
+ v_rolled = torch.cat((v_windows, v_rolled), 2)
302
+ else:
303
+ k_rolled = k_windows
304
+ v_rolled = v_windows
305
+
306
+ # q(k/v)_windows shape : [16, 4, 225, 128]
307
+ # k_rolled.shape : [16, 4, 5, 165, 128]
308
+ # ideal expanded window size 153 ((5+2*2)*(9+2*4))
309
+ # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
310
+
311
+ if self.pool_method != "none" and self.focal_level > 1:
312
+ k_pooled = []
313
+ v_pooled = []
314
+ for k in range(self.focal_level - 1):
315
+ stride = 2**k
316
+ # B, T, nWh, nWw, C
317
+ x_window_pooled = x_all[k + 1].permute(0, 3, 1, 2,
318
+ 4).contiguous()
319
+
320
+ nWh, nWw = x_window_pooled.shape[2:4]
321
+
322
+ # generate mask for pooled windows
323
+ mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
324
+ # unfold mask: [nWh*nWw//s//s, k*k, 1]
325
+ unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
326
+ 1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
327
+ view(nWh*nWw // stride // stride, -1, 1)
328
+
329
+ if k > 0:
330
+ valid_ind_unfold_k = getattr(
331
+ self, "valid_ind_unfold_{}".format(k))
332
+ unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
333
+
334
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
335
+ x_window_masks = x_window_masks.masked_fill(
336
+ x_window_masks == 0,
337
+ float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
338
+ mask_all[k + 1] = x_window_masks
339
+
340
+ # generate k and v for pooled windows
341
+ qkv_pooled = self.qkv(x_window_pooled).reshape(
342
+ B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
343
+ 3).view(3, -1, C, nWh,
344
+ nWw).contiguous()
345
+ # B*T, C, nWh, nWw
346
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]
347
+ # k_pooled_k shape: [5, 512, 4, 4]
348
+ # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
349
+
350
+ (k_pooled_k, v_pooled_k) = map(
351
+ lambda t: self.unfolds[k]
352
+ (t).view(B, T, C, self.unfolds[k].kernel_size[0], self.
353
+ unfolds[k].kernel_size[1], -1)
354
+ .permute(0, 5, 1, 3, 4, 2).contiguous().view(
355
+ -1, T, self.unfolds[k].kernel_size[0] * self.unfolds[
356
+ k].kernel_size[1], self.num_heads, C // self.
357
+ num_heads).permute(0, 3, 1, 2, 4).contiguous(),
358
+ # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
359
+ (k_pooled_k, v_pooled_k))
360
+ # k_pooled_k shape : [16, 4, 5, 45, 128]
361
+
362
+ # select valid unfolding index
363
+ if k > 0:
364
+ (k_pooled_k, v_pooled_k) = map(
365
+ lambda t: t[:, :, :, valid_ind_unfold_k],
366
+ (k_pooled_k, v_pooled_k))
367
+
368
+ k_pooled_k = k_pooled_k.view(
369
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
370
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
371
+ v_pooled_k = v_pooled_k.view(
372
+ -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
373
+ self.unfolds[k].kernel_size[1], C // self.num_heads)
374
+
375
+ k_pooled += [k_pooled_k]
376
+ v_pooled += [v_pooled_k]
377
+
378
+ # k_all (v_all) shape : [16, 4, 5 * 210, 128]
379
+ k_all = torch.cat([k_rolled] + k_pooled, 2)
380
+ v_all = torch.cat([v_rolled] + v_pooled, 2)
381
+ else:
382
+ k_all = k_rolled
383
+ v_all = v_rolled
384
+
385
+ N = k_all.shape[-2]
386
+ q_windows = q_windows * self.scale
387
+ # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
388
+ attn = (q_windows @ k_all.transpose(-2, -1))
389
+ # T * 45
390
+ window_area = T * self.window_size[0] * self.window_size[1]
391
+ # T * 165
392
+ window_area_rolled = k_rolled.shape[2]
393
+
394
+ if self.pool_method != "none" and self.focal_level > 1:
395
+ offset = window_area_rolled
396
+ for k in range(self.focal_level - 1):
397
+ # add attentional mask
398
+ # mask_all[1] shape [1, 16, T * 45]
399
+
400
+ bias = tuple((i + 2**k - 1) for i in self.focal_window)
401
+
402
+ if mask_all[k + 1] is not None:
403
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
404
+ attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
405
+ mask_all[k+1][:, :, None, None, :].repeat(
406
+ attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
407
+
408
+ offset += T * bias[0] * bias[1]
409
+
410
+ if mask_all[0] is not None:
411
+ nW = mask_all[0].shape[0]
412
+ attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
413
+ window_area, N)
414
+ attn[:, :, :, :, :
415
+ window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
416
+ None, :, None, :, :]
417
+ attn = attn.view(-1, self.num_heads, window_area, N)
418
+ attn = self.softmax(attn)
419
+ else:
420
+ attn = self.softmax(attn)
421
+
422
+ x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
423
+ C)
424
+ x = self.proj(x)
425
+ return x
426
+
427
+
428
+ class TemporalFocalTransformerBlock(nn.Module):
429
+ r""" Temporal Focal Transformer Block.
430
+ Args:
431
+ dim (int): Number of input channels.
432
+ num_heads (int): Number of attention heads.
433
+ window_size (tuple[int]): Window size.
434
+ shift_size (int): Shift size for SW-MSA.
435
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
436
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
437
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
438
+ focal_level (int): Number of focal levels.
439
+ focal_window (tuple[int]): Size of the focal window.
440
+ n_vecs (int): Required for F3N.
441
+ t2t_params (int): T2T parameters for F3N.
442
+ """
443
+ def __init__(self,
444
+ dim,
445
+ num_heads,
446
+ window_size=(5, 9),
447
+ mlp_ratio=4.,
448
+ qkv_bias=True,
449
+ pool_method="fc",
450
+ focal_level=2,
451
+ focal_window=(5, 9),
452
+ norm_layer=nn.LayerNorm,
453
+ n_vecs=None,
454
+ t2t_params=None):
455
+ super().__init__()
456
+ self.dim = dim
457
+ self.num_heads = num_heads
458
+ self.window_size = window_size
459
+ self.expand_size = tuple(i // 2 for i in window_size) # TODO
460
+ self.mlp_ratio = mlp_ratio
461
+ self.pool_method = pool_method
462
+ self.focal_level = focal_level
463
+ self.focal_window = focal_window
464
+
465
+ self.window_size_glo = self.window_size
466
+
467
+ self.pool_layers = nn.ModuleList()
468
+ if self.pool_method != "none":
469
+ for k in range(self.focal_level - 1):
470
+ window_size_glo = tuple(
471
+ math.floor(i / (2**k)) for i in self.window_size_glo)
472
+ self.pool_layers.append(
473
+ nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
474
+ self.pool_layers[-1].weight.data.fill_(
475
+ 1. / (window_size_glo[0] * window_size_glo[1]))
476
+ self.pool_layers[-1].bias.data.fill_(0)
477
+
478
+ self.norm1 = norm_layer(dim)
479
+
480
+ self.attn = WindowAttention(dim,
481
+ expand_size=self.expand_size,
482
+ window_size=self.window_size,
483
+ focal_window=focal_window,
484
+ focal_level=focal_level,
485
+ num_heads=num_heads,
486
+ qkv_bias=qkv_bias,
487
+ pool_method=pool_method)
488
+
489
+ self.norm2 = norm_layer(dim)
490
+ self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
491
+
492
+ def forward(self, x):
493
+ output_size = x[1]
494
+ x = x[0]
495
+
496
+ B, T, H, W, C = x.shape
497
+
498
+ shortcut = x
499
+ x = self.norm1(x)
500
+
501
+ shifted_x = x
502
+
503
+ x_windows_all = [shifted_x]
504
+ x_window_masks_all = [None]
505
+
506
+ # partition windows tuple(i // 2 for i in window_size)
507
+ if self.focal_level > 1 and self.pool_method != "none":
508
+ # if we add coarser granularity and the pool method is not none
509
+ for k in range(self.focal_level - 1):
510
+ window_size_glo = tuple(
511
+ math.floor(i / (2**k)) for i in self.window_size_glo)
512
+ pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
513
+ pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
514
+ H_pool = pooled_h * window_size_glo[0]
515
+ W_pool = pooled_w * window_size_glo[1]
516
+
517
+ x_level_k = shifted_x
518
+ # trim or pad shifted_x depending on the required size
519
+ if H > H_pool:
520
+ trim_t = (H - H_pool) // 2
521
+ trim_b = H - H_pool - trim_t
522
+ x_level_k = x_level_k[:, :, trim_t:-trim_b]
523
+ elif H < H_pool:
524
+ pad_t = (H_pool - H) // 2
525
+ pad_b = H_pool - H - pad_t
526
+ x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
527
+
528
+ if W > W_pool:
529
+ trim_l = (W - W_pool) // 2
530
+ trim_r = W - W_pool - trim_l
531
+ x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
532
+ elif W < W_pool:
533
+ pad_l = (W_pool - W) // 2
534
+ pad_r = W_pool - W - pad_l
535
+ x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
536
+
537
+ x_windows_noreshape = window_partition_noreshape(
538
+ x_level_k.contiguous(), window_size_glo
539
+ ) # B, nw, nw, T, window_size, window_size, C
540
+ nWh, nWw = x_windows_noreshape.shape[1:3]
541
+ x_windows_noreshape = x_windows_noreshape.view(
542
+ B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
543
+ C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
544
+ x_windows_pooled = self.pool_layers[k](
545
+ x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
546
+
547
+ x_windows_all += [x_windows_pooled]
548
+ x_window_masks_all += [None]
549
+
550
+ # nW*B, T*window_size*window_size, C
551
+ attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all)
552
+
553
+ # merge windows
554
+ attn_windows = attn_windows.view(-1, T, self.window_size[0],
555
+ self.window_size[1], C)
556
+ shifted_x = window_reverse(attn_windows, self.window_size, T, H,
557
+ W) # B T H' W' C
558
+
559
+ # FFN
560
+ x = shortcut + shifted_x
561
+ y = self.norm2(x)
562
+ x = x + self.mlp(y.view(B, T * H * W, C), output_size).view(
563
+ B, T, H, W, C)
564
+
565
+ return x, output_size
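Compared with the block above in tfocal_transformer.py, the HQ block threads a (features, output_size) tuple through forward, so a stack of blocks can be chained while the feed-forward keeps folding back to the frame-level feature size. A rough smoke-test sketch; the sizes are E2FGVI-style assumptions (a 60x108 feature map giving a 20x36 token grid), and the t2t parameters of kernel 7, stride 3, padding 3 are likewise assumed, matching the hard-coded 49 in FusionFeedForward:

import torch

t2t_params = {'kernel_size': (7, 7), 'stride': (3, 3), 'padding': (3, 3)}
block = TemporalFocalTransformerBlock(dim=512, num_heads=4, window_size=(5, 9),
                                      focal_level=2, focal_window=(5, 9),
                                      n_vecs=20 * 36, t2t_params=t2t_params)

x = torch.randn(1, 5, 20, 36, 512)        # B, T, token_h, token_w, C
out, output_size = block((x, (60, 108)))  # forward consumes and returns the (x, output_size) pair
assert out.shape == x.shape and output_size == (60, 108)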
inpainter/util/__init__.py ADDED
File without changes
inpainter/util/tensor_util.py ADDED
@@ -0,0 +1,24 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ # resize frames
5
+ def resize_frames(frames, size=None):
6
+ """
7
+ size: (w, h)
8
+ """
9
+ if size is not None:
10
+ frames = [cv2.resize(f, size) for f in frames]
11
+ frames = np.stack(frames, 0)
12
+
13
+ return frames
14
+
15
+ # resize masks
16
+ def resize_masks(masks, size=None):
17
+ """
18
+ size: (w, h)
19
+ """
20
+ if size is not None:
21
+ masks = [np.expand_dims(cv2.resize(m, size), 2) for m in masks]
22
+ masks = np.stack(masks, 0)
23
+
24
+ return masks
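A small usage sketch for these two helpers (dummy arrays; the import assumes the repository root is on sys.path). Note that size follows the cv2 convention of (w, h), while the stacked outputs come back as (n, h, w, c):

import numpy as np
from inpainter.util.tensor_util import resize_frames, resize_masks

frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(4)]
masks = [np.zeros((480, 640), dtype=np.uint8) for _ in range(4)]

print(resize_frames(frames, size=(320, 240)).shape)  # (4, 240, 320, 3)
print(resize_masks(masks, size=(320, 240)).shape)    # (4, 240, 320, 1)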
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ progressbar2
2
+ gdown
3
+ gitpython
4
+ git+https://github.com/cheind/py-thin-plate-spline
5
+ hickle
6
+ tensorboard
7
+ numpy
8
+ git+https://github.com/facebookresearch/segment-anything.git
9
+ gradio==3.25.0
10
+ opencv-python
11
+ pycocotools
12
+ matplotlib
13
+ onnxruntime
14
+ onnx
15
+ metaseg
16
+ pyyaml
17
+ av
sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
3
+ size 2564550879
template.html ADDED
@@ -0,0 +1,27 @@
1
+ <!-- template.html -->
2
+ <!DOCTYPE html>
3
+ <html lang="en">
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>Gradio Video Pause Time</title>
8
+ </head>
9
+ <body>
10
+ <video id="video" controls>
11
+ <source src="{{VIDEO_URL}}" type="video/mp4">
12
+ Your browser does not support the video tag.
13
+ </video>
14
+ <script>
15
+ const video = document.getElementById("video");
16
+ let pauseTime = null;
17
+
18
+ video.addEventListener("pause", () => {
19
+ pauseTime = video.currentTime;
20
+ });
21
+
22
+ function getPauseTime() {
23
+ return pauseTime;
24
+ }
25
+ </script>
26
+ </body>
27
+ </html>
templates/index.html ADDED
@@ -0,0 +1,50 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>Video Object Segmentation</title>
8
+ <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
9
+ </head>
10
+ <body>
11
+ <h1>Video Object Segmentation</h1>
12
+
13
+ <input type="file" id="video-input" accept="video/*">
14
+ <button id="upload-video">Upload Video</button>
15
+ <br>
16
+ <button id="template-select">Template Select</button>
17
+ <button id="sam-refine">SAM Refine</button>
18
+ <br>
19
+ <button id="track-video">Track Video</button>
20
+ <button id="track-image">Track Image</button>
21
+ <br>
22
+ <a href="/download_video" id="download-video" download>Download Video</a>
23
+
24
+ <script>
25
+ // JavaScript code for handling interactions with the server
26
+ $("#upload-video").click(function() {
27
+ var videoInput = document.getElementById("video-input");
28
+ var formData = new FormData();
29
+ formData.append("video", videoInput.files[0]);
30
+
31
+ $.ajax({
32
+ url: "/upload_video",
33
+ type: "POST",
34
+ data: formData,
35
+ processData: false,
36
+ contentType: false,
37
+ success: function(response) {
38
+ console.log(response);
39
+ // Process the response and update the UI accordingly
40
+ },
41
+ error: function(jqXHR, textStatus, errorThrown) {
42
+ console.log(textStatus, errorThrown);
43
+ }
44
+ });
45
+ });
46
+
47
+ </script>
48
+ </body>
49
+ </html>
50
+
text_server.py ADDED
@@ -0,0 +1,72 @@
1
+ import os
2
+ import sys
3
+ import cv2
4
+ import time
5
+ import json
6
+ import queue
7
+ import numpy as np
8
+ import requests
9
+ import concurrent.futures
10
+ from PIL import Image
11
+ from flask import Flask, render_template, request, jsonify, send_file
12
+ import torchvision
13
+ import torch
14
+
15
+ from demo import automask_image_app, automask_video_app, sahi_autoseg_app
16
+ sys.path.append(sys.path[0] + "/tracker")
17
+ sys.path.append(sys.path[0] + "/tracker/model")
18
+ from track_anything import TrackingAnything
19
+ from track_anything import parse_augment
20
+
21
+ # ... (all the functions defined in the original code except the Gradio part)
22
+
23
+ app = Flask(__name__)
24
+ app.config['UPLOAD_FOLDER'] = './uploaded_videos'
25
+ app.config['ALLOWED_EXTENSIONS'] = {'mp4', 'avi', 'mov', 'mkv'}
26
+
27
+
28
+ def allowed_file(filename):
29
+ return '.' in filename and filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']
30
+
31
+ @app.route("/")
32
+ def index():
33
+ return render_template("index.html")
34
+
35
+ @app.route("/upload_video", methods=["POST"])
36
+ def upload_video():
37
+ # ... (handle video upload and processing)
38
+ return jsonify(status="success", data=video_data)
39
+
40
+ @app.route("/template_select", methods=["POST"])
41
+ def template_select():
42
+ # ... (handle template selection and processing)
43
+ return jsonify(status="success", data=template_data)
44
+
45
+ @app.route("/sam_refine", methods=["POST"])
46
+ def sam_refine_request():
47
+ # ... (handle sam refine and processing)
48
+ return jsonify(status="success", data=sam_data)
49
+
50
+ @app.route("/track_video", methods=["POST"])
51
+ def track_video():
52
+ # ... (handle video tracking and processing)
53
+ return jsonify(status="success", data=tracking_data)
54
+
55
+ @app.route("/track_image", methods=["POST"])
56
+ def track_image():
57
+ # ... (handle image tracking and processing)
58
+ return jsonify(status="success", data=tracking_data)
59
+
60
+ @app.route("/download_video", methods=["GET"])
61
+ def download_video():
62
+ try:
63
+ return send_file("output.mp4", attachment_filename="output.mp4")
64
+ except Exception as e:
65
+ return str(e)
66
+
67
+ if __name__ == "__main__":
+ app.run(host="0.0.0.0", port=12212, debug=True)
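Once the elided handler bodies are filled in, the routes above can be exercised from a small client. A hypothetical sketch using requests, with a placeholder file name, the default port from this file, and the "video" form field expected by templates/index.html:

import requests

with open("demo.mp4", "rb") as f:
    resp = requests.post("http://127.0.0.1:12212/upload_video", files={"video": f})
print(resp.status_code, resp.json())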
tools/__init__.py ADDED
File without changes
tools/base_segmenter.py ADDED
@@ -0,0 +1,129 @@
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ import matplotlib.pyplot as plt
9
+ import PIL
10
+ from .mask_painter import mask_painter
11
+
12
+
13
+ class BaseSegmenter:
14
+ def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
15
+ """
16
+ device: model device
17
+ SAM_checkpoint: path of SAM checkpoint
18
+ model_type: vit_b, vit_l, vit_h
19
+ """
20
+ print(f"Initializing BaseSegmenter to {device}")
21
+ assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
22
+
23
+ self.device = device
24
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
25
+ self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
26
+ self.model.to(device=self.device)
27
+ self.predictor = SamPredictor(self.model)
28
+ self.embedded = False
29
+
30
+ @torch.no_grad()
31
+ def set_image(self, image: np.ndarray):
32
+ # expects a 3-channel RGB numpy array (e.g. loaded via PIL Image.open)
33
+ # image embedding: avoid encoding the same image multiple times
34
+ self.orignal_image = image
35
+ if self.embedded:
36
+ print('repeat embedding, please reset_image.')
37
+ return
38
+ self.predictor.set_image(image)
39
+ self.embedded = True
40
+ return
41
+
42
+ @torch.no_grad()
43
+ def reset_image(self):
44
+ # reset image embeding
45
+ self.predictor.reset_image()
46
+ self.embedded = False
47
+
48
+ def predict(self, prompts, mode, multimask=True):
49
+ """
50
+ image: numpy array, h, w, 3
51
+ prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
52
+ prompts['point_coords']: numpy array [N,2]
53
+ prompts['point_labels']: numpy array [1,N]
54
+ prompts['mask_input']: numpy array [1,256,256]
55
+ mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
56
+ multimask: True (return 3 masks), False (return 1 mask only)
57
+ when multimask=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
58
+ """
59
+ assert self.embedded, 'prediction is called before set_image (feature embedding).'
60
+ assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'
61
+
62
+ if mode == 'point':
63
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
64
+ point_labels=prompts['point_labels'],
65
+ multimask_output=multimask)
66
+ elif mode == 'mask':
67
+ masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
68
+ multimask_output=multimask)
69
+ elif mode == 'both': # both
70
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
71
+ point_labels=prompts['point_labels'],
72
+ mask_input=prompts['mask_input'],
73
+ multimask_output=multimask)
74
+ else:
75
+ raise NotImplementedError("mode not implemented yet!")
76
+ # masks (n, h, w), scores (n,), logits (n, 256, 256)
77
+ return masks, scores, logits
78
+
79
+
80
+ if __name__ == "__main__":
81
+ # load and show an image
82
+ image = cv2.imread('/hhd3/gaoshang/truck.jpg')
83
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
84
+
85
+ # initialise BaseSegmenter
86
+ SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
87
+ model_type = 'vit_h'
88
+ device = "cuda:4"
89
+ base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
90
+
91
+ # image embedding (once embedded, multiple prompts can be applied)
92
+ base_segmenter.set_image(image)
93
+
94
+ # examples
95
+ # point only ------------------------
96
+ mode = 'point'
97
+ prompts = {
98
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
99
+ 'point_labels': np.array([1, 1]),
100
+ }
101
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
102
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
103
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
104
+ cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
105
+
106
+ # both ------------------------
107
+ mode = 'both'
108
+ mask_input = logits[np.argmax(scores), :, :]
109
+ prompts = {'mask_input': mask_input [None, :, :]}
110
+ prompts = {
111
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
112
+ 'point_labels': np.array([1, 0]),
113
+ 'mask_input': mask_input[None, :, :]
114
+ }
115
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
116
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
117
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
118
+ cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
119
+
120
+ # mask only ------------------------
121
+ mode = 'mask'
122
+ mask_input = logits[np.argmax(scores), :, :]
123
+
124
+ prompts = {'mask_input': mask_input[None, :, :]}
125
+
126
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
127
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
128
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
129
+ cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
tools/interact_tools.py ADDED
@@ -0,0 +1,265 @@
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ import matplotlib.pyplot as plt
9
+ import PIL
10
+ from .mask_painter import mask_painter as mask_painter2
11
+ from .base_segmenter import BaseSegmenter
12
+ from .painter import mask_painter, point_painter
13
+ import os
14
+ import requests
15
+ import sys
16
+
17
+
18
+ mask_color = 3
19
+ mask_alpha = 0.7
20
+ contour_color = 1
21
+ contour_width = 5
22
+ point_color_ne = 8
23
+ point_color_ps = 50
24
+ point_alpha = 0.9
25
+ point_radius = 15
26
+ contour_color = 2
27
+ contour_width = 5
28
+
29
+
30
+ class SamControler():
31
+ def __init__(self, SAM_checkpoint, model_type, device):
32
+ '''
33
+ initialize sam controler
34
+ '''
35
+
36
+
37
+ self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
38
+
39
+
40
+ def seg_again(self, image: np.ndarray):
41
+ '''
42
+ it is used when interact in video
43
+ '''
44
+ self.sam_controler.reset_image()
45
+ self.sam_controler.set_image(image)
46
+ return
47
+
48
+
49
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
50
+ '''
51
+ it is used in first frame in video
52
+ return: mask, logit, painted image(mask+point)
53
+ '''
54
+ # self.sam_controler.set_image(image)
55
+ origal_image = self.sam_controler.orignal_image
56
+ neg_flag = labels[-1]
57
+ if neg_flag==1:
58
+ #find neg
59
+ prompts = {
60
+ 'point_coords': points,
61
+ 'point_labels': labels,
62
+ }
63
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
64
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
65
+ prompts = {
66
+ 'point_coords': points,
67
+ 'point_labels': labels,
68
+ 'mask_input': logit[None, :, :]
69
+ }
70
+ masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
71
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
72
+ else:
73
+ #find positive
74
+ prompts = {
75
+ 'point_coords': points,
76
+ 'point_labels': labels,
77
+ }
78
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
79
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
80
+
81
+
82
+ assert len(points)==len(labels)
83
+
84
+ painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
85
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
86
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
87
+ painted_image = Image.fromarray(painted_image)
88
+
89
+ return mask, logit, painted_image
90
+
91
+ def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
92
+ origal_image = self.sam_controler.orignal_image
93
+ if same:
94
+ '''
95
+ true; loop in the same image
96
+ '''
97
+ prompts = {
98
+ 'point_coords': points,
99
+ 'point_labels': labels,
100
+ 'mask_input': logits[None, :, :]
101
+ }
102
+ masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
103
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
104
+
105
+ painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
106
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
107
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
108
+ painted_image = Image.fromarray(painted_image)
109
+
110
+ return mask, logit, painted_image
111
+ else:
112
+ '''
113
+ loop in the different image, interact in the video
114
+ '''
115
+ if image is None:
116
+ raise('Image error')
117
+ else:
118
+ self.seg_again(image)
119
+ prompts = {
120
+ 'point_coords': points,
121
+ 'point_labels': labels,
122
+ }
123
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
124
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
125
+
126
+ painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
127
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
128
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
129
+ painted_image = Image.fromarray(painted_image)
130
+
131
+ return mask, logit, painted_image
132
+
133
+
134
+
135
+
136
+
137
+
138
+ # def initialize():
139
+ # '''
140
+ # initialize sam controler
141
+ # '''
142
+ # checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
143
+ # folder = "segmenter"
144
+ # SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
145
+ # download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
146
+
147
+
148
+ # model_type = 'vit_h'
149
+ # device = "cuda:0"
150
+ # sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
151
+ # return sam_controler
152
+
153
+
154
+ # def seg_again(sam_controler, image: np.ndarray):
155
+ # '''
156
+ # it is used when interact in video
157
+ # '''
158
+ # sam_controler.reset_image()
159
+ # sam_controler.set_image(image)
160
+ # return
161
+
162
+
163
+ # def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
164
+ # '''
165
+ # it is used in first frame in video
166
+ # return: mask, logit, painted image(mask+point)
167
+ # '''
168
+ # sam_controler.set_image(image)
169
+ # prompts = {
170
+ # 'point_coords': points,
171
+ # 'point_labels': labels,
172
+ # }
173
+ # masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
174
+ # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
175
+
176
+ # assert len(points)==len(labels)
177
+
178
+ # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
179
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
180
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
181
+ # painted_image = Image.fromarray(painted_image)
182
+
183
+ # return mask, logit, painted_image
184
+
185
+ # def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
186
+ # if same:
187
+ # '''
188
+ # true; loop in the same image
189
+ # '''
190
+ # prompts = {
191
+ # 'point_coords': points,
192
+ # 'point_labels': labels,
193
+ # 'mask_input': logits[None, :, :]
194
+ # }
195
+ # masks, scores, logits = sam_controler.predict(prompts, 'both', multimask)
196
+ # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
197
+
198
+ # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
199
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
200
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
201
+ # painted_image = Image.fromarray(painted_image)
202
+
203
+ # return mask, logit, painted_image
204
+ # else:
205
+ # '''
206
+ # loop in the different image, interact in the video
207
+ # '''
208
+ # if image is None:
209
+ # raise('Image error')
210
+ # else:
211
+ # seg_again(sam_controler, image)
212
+ # prompts = {
213
+ # 'point_coords': points,
214
+ # 'point_labels': labels,
215
+ # }
216
+ # masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
217
+ # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
218
+
219
+ # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
220
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
221
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
222
+ # painted_image = Image.fromarray(painted_image)
223
+
224
+ # return mask, logit, painted_image
225
+
226
+
227
+
228
+
229
+ if __name__ == "__main__":
230
+ points = np.array([[500, 375], [1125, 625]])
231
+ labels = np.array([1, 1])
232
+ image = cv2.imread('/hhd3/gaoshang/truck.jpg')
233
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
234
+
235
+ sam_controler = SamControler('./checkpoints/sam_vit_h_4b8939.pth', 'vit_h', 'cuda:0')
+ sam_controler.sam_controler.set_image(image)
236
+ mask, logit, painted_image_full = sam_controler.first_frame_click(image, points, labels, multimask=True)
237
+ painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
238
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
239
+ cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
240
+ cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
241
+ painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
242
+
243
+ mask, logit, painted_image_full = sam_controler.interact_loop(image, True, points, np.array([1, 0]), logit, multimask=True)
244
+ painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
245
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
246
+ cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
247
+ painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
248
+
249
+ mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
250
+ painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
251
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
252
+ cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
253
+ painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')
254
+
255
+
256
+
257
+
258
+
259
+
260
+
261
+
262
+
263
+
264
+
265
+
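The commented-out interact_loop above captures the two interaction cases: further clicks on the same frame reuse the previous low-resolution logit as a mask prompt, while clicks on a new frame require re-embedding the image first. A minimal sketch of the same idea against the plain segment_anything SamPredictor API (the checkpoint and image paths are placeholders):

import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

predictor = SamPredictor(sam_model_registry['vit_h'](checkpoint='./sam_vit_h_4b8939.pth'))

image = cv2.cvtColor(cv2.imread('./images/truck.jpg'), cv2.COLOR_BGR2RGB)  # placeholder image
predictor.set_image(image)

points = np.array([[500, 375]])
labels = np.array([1])

# first click on a frame: point prompt only
masks, scores, logits = predictor.predict(point_coords=points, point_labels=labels,
                                          multimask_output=True)
best = int(np.argmax(scores))

# further clicks on the same frame: feed the best logit back as a mask prompt
masks, scores, logits = predictor.predict(point_coords=points, point_labels=labels,
                                          mask_input=logits[best][None, :, :],
                                          multimask_output=True)

# moving to a different frame: call predictor.set_image(new_frame) again,
# then prompt with points only, as in the first call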
tools/mask_painter.py ADDED
@@ -0,0 +1,288 @@
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import copy
6
+ import time
7
+
8
+
9
+ def colormap(rgb=True):
10
+ color_list = np.array(
11
+ [
12
+ 0.000, 0.000, 0.000,
13
+ 1.000, 1.000, 1.000,
14
+ 1.000, 0.498, 0.313,
15
+ 0.392, 0.581, 0.929,
16
+ 0.000, 0.447, 0.741,
17
+ 0.850, 0.325, 0.098,
18
+ 0.929, 0.694, 0.125,
19
+ 0.494, 0.184, 0.556,
20
+ 0.466, 0.674, 0.188,
21
+ 0.301, 0.745, 0.933,
22
+ 0.635, 0.078, 0.184,
23
+ 0.300, 0.300, 0.300,
24
+ 0.600, 0.600, 0.600,
25
+ 1.000, 0.000, 0.000,
26
+ 1.000, 0.500, 0.000,
27
+ 0.749, 0.749, 0.000,
28
+ 0.000, 1.000, 0.000,
29
+ 0.000, 0.000, 1.000,
30
+ 0.667, 0.000, 1.000,
31
+ 0.333, 0.333, 0.000,
32
+ 0.333, 0.667, 0.000,
33
+ 0.333, 1.000, 0.000,
34
+ 0.667, 0.333, 0.000,
35
+ 0.667, 0.667, 0.000,
36
+ 0.667, 1.000, 0.000,
37
+ 1.000, 0.333, 0.000,
38
+ 1.000, 0.667, 0.000,
39
+ 1.000, 1.000, 0.000,
40
+ 0.000, 0.333, 0.500,
41
+ 0.000, 0.667, 0.500,
42
+ 0.000, 1.000, 0.500,
43
+ 0.333, 0.000, 0.500,
44
+ 0.333, 0.333, 0.500,
45
+ 0.333, 0.667, 0.500,
46
+ 0.333, 1.000, 0.500,
47
+ 0.667, 0.000, 0.500,
48
+ 0.667, 0.333, 0.500,
49
+ 0.667, 0.667, 0.500,
50
+ 0.667, 1.000, 0.500,
51
+ 1.000, 0.000, 0.500,
52
+ 1.000, 0.333, 0.500,
53
+ 1.000, 0.667, 0.500,
54
+ 1.000, 1.000, 0.500,
55
+ 0.000, 0.333, 1.000,
56
+ 0.000, 0.667, 1.000,
57
+ 0.000, 1.000, 1.000,
58
+ 0.333, 0.000, 1.000,
59
+ 0.333, 0.333, 1.000,
60
+ 0.333, 0.667, 1.000,
61
+ 0.333, 1.000, 1.000,
62
+ 0.667, 0.000, 1.000,
63
+ 0.667, 0.333, 1.000,
64
+ 0.667, 0.667, 1.000,
65
+ 0.667, 1.000, 1.000,
66
+ 1.000, 0.000, 1.000,
67
+ 1.000, 0.333, 1.000,
68
+ 1.000, 0.667, 1.000,
69
+ 0.167, 0.000, 0.000,
70
+ 0.333, 0.000, 0.000,
71
+ 0.500, 0.000, 0.000,
72
+ 0.667, 0.000, 0.000,
73
+ 0.833, 0.000, 0.000,
74
+ 1.000, 0.000, 0.000,
75
+ 0.000, 0.167, 0.000,
76
+ 0.000, 0.333, 0.000,
77
+ 0.000, 0.500, 0.000,
78
+ 0.000, 0.667, 0.000,
79
+ 0.000, 0.833, 0.000,
80
+ 0.000, 1.000, 0.000,
81
+ 0.000, 0.000, 0.167,
82
+ 0.000, 0.000, 0.333,
83
+ 0.000, 0.000, 0.500,
84
+ 0.000, 0.000, 0.667,
85
+ 0.000, 0.000, 0.833,
86
+ 0.000, 0.000, 1.000,
87
+ 0.143, 0.143, 0.143,
88
+ 0.286, 0.286, 0.286,
89
+ 0.429, 0.429, 0.429,
90
+ 0.571, 0.571, 0.571,
91
+ 0.714, 0.714, 0.714,
92
+ 0.857, 0.857, 0.857
93
+ ]
94
+ ).astype(np.float32)
95
+ color_list = color_list.reshape((-1, 3)) * 255
96
+ if not rgb:
97
+ color_list = color_list[:, ::-1]
98
+ return color_list
99
+
100
+
101
+ color_list = colormap()
102
+ color_list = color_list.astype('uint8').tolist()
103
+
104
+
105
+ def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
106
+ background_color = np.array(background_color)
107
+ contour_color = np.array(contour_color)
108
+
109
+ # background_mask = 1 - background_mask
110
+ # contour_mask = 1 - contour_mask
111
+
112
+ for i in range(3):
113
+ image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
114
+ + background_color[i] * (background_alpha-background_mask*background_alpha)
115
+
116
+ image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
117
+ + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
118
+
119
+ return image.astype('uint8')
120
+
121
+
122
+ def mask_generator_00(mask, background_radius, contour_radius):
123
+ # no background width when '00'
124
+ # distance map
125
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
126
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
127
+ dist_map = dist_transform_fore - dist_transform_back
128
+ # ...:::!!!:::...
129
+ contour_radius += 2
130
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
131
+ contour_mask = contour_mask / np.max(contour_mask)
132
+ contour_mask[contour_mask>0.5] = 1.
133
+
134
+ return mask, contour_mask
135
+
136
+
137
+ def mask_generator_01(mask, background_radius, contour_radius):
138
+ # no background width when '01'
139
+ # distance map
140
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
141
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
142
+ dist_map = dist_transform_fore - dist_transform_back
143
+ # ...:::!!!:::...
144
+ contour_radius += 2
145
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
146
+ contour_mask = contour_mask / np.max(contour_mask)
147
+ return mask, contour_mask
148
+
149
+
150
+ def mask_generator_10(mask, background_radius, contour_radius):
151
+ # distance map
152
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
153
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
154
+ dist_map = dist_transform_fore - dist_transform_back
155
+ # .....:::::!!!!!
156
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
157
+ background_mask = (background_mask - np.min(background_mask))
158
+ background_mask = background_mask / np.max(background_mask)
159
+ # ...:::!!!:::...
160
+ contour_radius += 2
161
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
162
+ contour_mask = contour_mask / np.max(contour_mask)
163
+ contour_mask[contour_mask>0.5] = 1.
164
+ return background_mask, contour_mask
165
+
166
+
167
+ def mask_generator_11(mask, background_radius, contour_radius):
168
+ # distance map
169
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
170
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
171
+ dist_map = dist_transform_fore - dist_transform_back
172
+ # .....:::::!!!!!
173
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
174
+ background_mask = (background_mask - np.min(background_mask))
175
+ background_mask = background_mask / np.max(background_mask)
176
+ # ...:::!!!:::...
177
+ contour_radius += 2
178
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
179
+ contour_mask = contour_mask / np.max(contour_mask)
180
+ return background_mask, contour_mask
181
+
182
+
183
+ def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
184
+ """
185
+ Input:
186
+ input_image: numpy array (H, W, 3)
187
+ input_mask: numpy array (H, W)
188
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
189
+ background_blur_radius: radius of background blur, must be odd number
190
+ contour_width: width of mask contour, must be odd number
191
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
192
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
193
+ mode: painting mode; '00': no blur, '01': only blur contour, '10': only blur background, '11': blur both
194
+
195
+ Output:
196
+ painted_image: numpy array
197
+ """
198
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
199
+ assert background_blur_radius % 2 == 1 and contour_width % 2 == 1, 'background_blur_radius and contour_width must be ODD'
200
+ assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
201
+
202
+ # downsample input image and mask
203
+ height, width = input_image.shape[0], input_image.shape[1]
204
+ res = 1024
205
+ ratio = min(1.0 * res / max(height, width), 1.0)
206
+ input_image = cv2.resize(input_image, (int(width*ratio), int(height*ratio)))
207
+ input_mask = cv2.resize(input_mask, (int(width*ratio), int(height*ratio)))
208
+
209
+ # 0: background, 1: foreground
210
+ msk = np.clip(input_mask, 0, 1)
211
+
212
+ # generate masks for background and contour pixels
213
+ background_radius = (background_blur_radius - 1) // 2
214
+ contour_radius = (contour_width - 1) // 2
215
+ generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
216
+ background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
217
+
218
+ # paint
219
+ painted_image = vis_add_mask\
220
+ (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
221
+
222
+ return painted_image
223
+
224
+
225
+ if __name__ == '__main__':
226
+
227
+ background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
228
+ background_blur_radius = 31 # radius of background blur, must be odd number
229
+ contour_width = 11 # contour width, must be odd number
230
+ contour_color = 3 # id in color map, 0: black, 1: white, >1: others
231
+ contour_alpha = 1 # transparency of contour, 0: no contour highlighted
232
+
233
+ # load input image and mask
234
+ input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
235
+ input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
236
+
237
+ # paint
238
+ overall_time_1 = 0
239
+ overall_time_2 = 0
240
+ overall_time_3 = 0
241
+ overall_time_4 = 0
242
+ overall_time_5 = 0
243
+
244
+ for i in range(50):
245
+ t2 = time.time()
246
+ painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
247
+ e2 = time.time()
248
+
249
+ t3 = time.time()
250
+ painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
251
+ e3 = time.time()
252
+
253
+ t1 = time.time()
254
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
255
+ e1 = time.time()
256
+
257
+ t4 = time.time()
258
+ painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
259
+ e4 = time.time()
260
+
261
+ t5 = time.time()
262
+ painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
263
+ e5 = time.time()
264
+
265
+ overall_time_1 += (e1 - t1)
266
+ overall_time_2 += (e2 - t2)
267
+ overall_time_3 += (e3 - t3)
268
+ overall_time_4 += (e4 - t4)
269
+ overall_time_5 += (e5 - t5)
270
+
271
+ print(f'average time w gaussian: {overall_time_1/50}')
272
+ print(f'average time w/o gaussian00: {overall_time_2/50}')
273
+ print(f'average time w/o gaussian10: {overall_time_3/50}')
274
+ print(f'average time w/o gaussian01: {overall_time_4/50}')
275
+ print(f'average time w/o gaussian11: {overall_time_5/50}')
276
+
277
+ # save
278
+ painted_image_00 = Image.fromarray(painted_image_00)
279
+ painted_image_00.save('./test_img/painter_output_image_00.png')
280
+
281
+ painted_image_10 = Image.fromarray(painted_image_10)
282
+ painted_image_10.save('./test_img/painter_output_image_10.png')
283
+
284
+ painted_image_01 = Image.fromarray(painted_image_01)
285
+ painted_image_01.save('./test_img/painter_output_image_01.png')
286
+
287
+ painted_image_11 = Image.fromarray(painted_image_11)
288
+ painted_image_11.save('./test_img/painter_output_image_11.png')
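For quick experimentation, mask_painter can also be driven with synthetic data; the circle mask and 256x256 size below are made-up values, and the import assumes the repository root is on sys.path:

import numpy as np
from PIL import Image
from tools.mask_painter import mask_painter

# synthetic grey image with a filled-circle mask
h, w = 256, 256
image = np.full((h, w, 3), 200, dtype=np.uint8)
yy, xx = np.mgrid[:h, :w]
mask = (((yy - h // 2) ** 2 + (xx - w // 2) ** 2) < 60 ** 2).astype(np.uint8)

# try all four painting modes; blur radius and contour width must stay odd
for mode in ['00', '01', '10', '11']:
    painted = mask_painter(image.copy(), mask,
                           background_alpha=0.7, background_blur_radius=31,
                           contour_width=11, contour_color=3, contour_alpha=1,
                           mode=mode)
    Image.fromarray(painted).save(f'painter_mode_{mode}.png')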
tools/painter.py ADDED
@@ -0,0 +1,215 @@
1
+ # paint masks, contours, or points on images, with specified colors
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import copy
7
+ import time
8
+
9
+
10
+ def colormap(rgb=True):
11
+ color_list = np.array(
12
+ [
13
+ 0.000, 0.000, 0.000,
14
+ 1.000, 1.000, 1.000,
15
+ 1.000, 0.498, 0.313,
16
+ 0.392, 0.581, 0.929,
17
+ 0.000, 0.447, 0.741,
18
+ 0.850, 0.325, 0.098,
19
+ 0.929, 0.694, 0.125,
20
+ 0.494, 0.184, 0.556,
21
+ 0.466, 0.674, 0.188,
22
+ 0.301, 0.745, 0.933,
23
+ 0.635, 0.078, 0.184,
24
+ 0.300, 0.300, 0.300,
25
+ 0.600, 0.600, 0.600,
26
+ 1.000, 0.000, 0.000,
27
+ 1.000, 0.500, 0.000,
28
+ 0.749, 0.749, 0.000,
29
+ 0.000, 1.000, 0.000,
30
+ 0.000, 0.000, 1.000,
31
+ 0.667, 0.000, 1.000,
32
+ 0.333, 0.333, 0.000,
33
+ 0.333, 0.667, 0.000,
34
+ 0.333, 1.000, 0.000,
35
+ 0.667, 0.333, 0.000,
36
+ 0.667, 0.667, 0.000,
37
+ 0.667, 1.000, 0.000,
38
+ 1.000, 0.333, 0.000,
39
+ 1.000, 0.667, 0.000,
40
+ 1.000, 1.000, 0.000,
41
+ 0.000, 0.333, 0.500,
42
+ 0.000, 0.667, 0.500,
43
+ 0.000, 1.000, 0.500,
44
+ 0.333, 0.000, 0.500,
45
+ 0.333, 0.333, 0.500,
46
+ 0.333, 0.667, 0.500,
47
+ 0.333, 1.000, 0.500,
48
+ 0.667, 0.000, 0.500,
49
+ 0.667, 0.333, 0.500,
50
+ 0.667, 0.667, 0.500,
51
+ 0.667, 1.000, 0.500,
52
+ 1.000, 0.000, 0.500,
53
+ 1.000, 0.333, 0.500,
54
+ 1.000, 0.667, 0.500,
55
+ 1.000, 1.000, 0.500,
56
+ 0.000, 0.333, 1.000,
57
+ 0.000, 0.667, 1.000,
58
+ 0.000, 1.000, 1.000,
59
+ 0.333, 0.000, 1.000,
60
+ 0.333, 0.333, 1.000,
61
+ 0.333, 0.667, 1.000,
62
+ 0.333, 1.000, 1.000,
63
+ 0.667, 0.000, 1.000,
64
+ 0.667, 0.333, 1.000,
65
+ 0.667, 0.667, 1.000,
66
+ 0.667, 1.000, 1.000,
67
+ 1.000, 0.000, 1.000,
68
+ 1.000, 0.333, 1.000,
69
+ 1.000, 0.667, 1.000,
70
+ 0.167, 0.000, 0.000,
71
+ 0.333, 0.000, 0.000,
72
+ 0.500, 0.000, 0.000,
73
+ 0.667, 0.000, 0.000,
74
+ 0.833, 0.000, 0.000,
75
+ 1.000, 0.000, 0.000,
76
+ 0.000, 0.167, 0.000,
77
+ 0.000, 0.333, 0.000,
78
+ 0.000, 0.500, 0.000,
79
+ 0.000, 0.667, 0.000,
80
+ 0.000, 0.833, 0.000,
81
+ 0.000, 1.000, 0.000,
82
+ 0.000, 0.000, 0.167,
83
+ 0.000, 0.000, 0.333,
84
+ 0.000, 0.000, 0.500,
85
+ 0.000, 0.000, 0.667,
86
+ 0.000, 0.000, 0.833,
87
+ 0.000, 0.000, 1.000,
88
+ 0.143, 0.143, 0.143,
89
+ 0.286, 0.286, 0.286,
90
+ 0.429, 0.429, 0.429,
91
+ 0.571, 0.571, 0.571,
92
+ 0.714, 0.714, 0.714,
93
+ 0.857, 0.857, 0.857
94
+ ]
95
+ ).astype(np.float32)
96
+ color_list = color_list.reshape((-1, 3)) * 255
97
+ if not rgb:
98
+ color_list = color_list[:, ::-1]
99
+ return color_list
100
+
101
+
102
+ color_list = colormap()
103
+ color_list = color_list.astype('uint8').tolist()
104
+
105
+
106
+ def vis_add_mask(image, mask, color, alpha):
107
+ color = np.array(color_list[color])
108
+ mask = mask > 0.5
109
+ image[mask] = image[mask] * (1-alpha) + color * alpha
110
+ return image.astype('uint8')
111
+
112
+ def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
113
+ h, w = input_image.shape[:2]
114
+ point_mask = np.zeros((h, w)).astype('uint8')
115
+ for point in input_points:
116
+ point_mask[point[1], point[0]] = 1
117
+
118
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (point_radius, point_radius))
119
+ point_mask = cv2.dilate(point_mask, kernel)
120
+
121
+ contour_radius = (contour_width - 1) // 2
122
+ dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
123
+ dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3)
124
+ dist_map = dist_transform_fore - dist_transform_back
125
+ # ...:::!!!:::...
126
+ contour_radius += 2
127
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
128
+ contour_mask = contour_mask / np.max(contour_mask)
129
+ contour_mask[contour_mask>0.5] = 1.
130
+
131
+ # paint mask
132
+ painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha)
133
+ # paint contour
134
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
135
+ return painted_image
136
+
137
+ def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
138
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
139
+ # 0: background, 1: foreground
140
+ mask = np.clip(input_mask, 0, 1)
141
+ contour_radius = (contour_width - 1) // 2
142
+
143
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
144
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
145
+ dist_map = dist_transform_fore - dist_transform_back
146
+ # ...:::!!!:::...
147
+ contour_radius += 2
148
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
149
+ contour_mask = contour_mask / np.max(contour_mask)
150
+ contour_mask[contour_mask>0.5] = 1.
151
+
152
+ # paint mask
153
+ painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha)
154
+ # paint contour
155
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
156
+
157
+ return painted_image
158
+
159
+ def background_remover(input_image, input_mask):
160
+ """
161
+ input_image: H, W, 3, np.array
162
+ input_mask: H, W, np.array
163
+
164
+ image_wo_background: PIL.Image
165
+ """
166
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
167
+ # 0: background, 1: foreground
168
+ mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255
169
+ image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
170
+ image_wo_background = Image.fromarray(image_wo_background).convert('RGBA')
171
+
172
+ return image_wo_background
173
+
174
+ if __name__ == '__main__':
175
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
176
+ input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))
177
+
178
+ # example of mask painter
179
+ mask_color = 3
180
+ mask_alpha = 0.7
181
+ contour_color = 1
182
+ contour_width = 5
183
+
184
+ # save
185
+ painted_image = Image.fromarray(input_image)
186
+ painted_image.save('images/original.png')
187
+
188
+ painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
189
+ # save
190
+ painted_image = Image.fromarray(input_image)
191
+ painted_image.save('images/original1.png')
192
+
193
+ # example of point painter
194
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
195
+ input_points = np.array([[500, 375], [70, 600]]) # x, y
196
+ point_color = 5
197
+ point_alpha = 0.9
198
+ point_radius = 15
199
+ contour_color = 2
200
+ contour_width = 5
201
+ painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
202
+ # save
203
+ painted_image = Image.fromarray(painted_image_1)
204
+ painted_image.save('images/point_painter_1.png')
205
+
206
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
207
+ painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
208
+ # save
209
+ painted_image = Image.fromarray(painted_image_2)
210
+ painted_image.save('images/point_painter_2.png')
211
+
212
+ # example of background remover
213
+ input_image = np.array(Image.open('images/original.png').convert('RGB'))
214
+ image_wo_background = background_remover(input_image, input_mask) # return PIL.Image
215
+ image_wo_background.save('images/image_wo_background.png')
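The two painters compose naturally: paint the mask first, then overlay the prompt points on the result, which is how the interactive tools display clicks. A short sketch on synthetic data (sizes and coordinates are illustrative only, and the imports assume the repository root is on sys.path):

import numpy as np
from PIL import Image
from tools.painter import mask_painter, point_painter, background_remover

h, w = 240, 320
image = np.full((h, w, 3), 180, dtype=np.uint8)
mask = np.zeros((h, w), dtype=np.uint8)
mask[60:180, 100:220] = 1                       # rectangular "object"
points = np.array([[160, 120], [40, 40]])       # (x, y) order, as in the example above

painted = mask_painter(image, mask, mask_color=3, mask_alpha=0.7,
                       contour_color=1, contour_width=3)
painted = point_painter(painted, points, point_color=5, point_alpha=0.9,
                        point_radius=15, contour_color=2, contour_width=5)
Image.fromarray(painted).save('painted_points.png')

# RGBA cut-out of the masked object
background_remover(image, mask).save('object_rgba.png')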
track_anything.py ADDED
@@ -0,0 +1,93 @@
1
+ import sys
2
+ sys.path.append("/hhd3/gaoshang/Track-Anything/tracker")
3
+ import PIL
4
+ from tools.interact_tools import SamControler
5
+ from tracker.base_tracker import BaseTracker
6
+ import numpy as np
7
+ import argparse
8
+
9
+
10
+
11
+ class TrackingAnything():
12
+ def __init__(self, sam_checkpoint, xmem_checkpoint, args):
13
+ self.args = args
14
+ self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
15
+ self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
16
+
17
+
18
+ def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
19
+ same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
20
+ if first_flag:
21
+ mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
22
+ return mask, logit, painted_image
23
+
24
+ if interact_flag:
25
+ mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
26
+ return mask, logit, painted_image
27
+
28
+ mask, logit, painted_image = self.xmem.track(image, logits)
29
+ return mask, logit, painted_image
30
+
31
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
32
+ mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
33
+ return mask, logit, painted_image
34
+
35
+ def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
36
+ mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
37
+ return mask, logit, painted_image
38
+
39
+ def generator(self, images: list, template_mask:np.ndarray):
40
+
41
+ masks = []
42
+ logits = []
43
+ painted_images = []
44
+ for i in range(len(images)):
45
+ if i ==0:
46
+ mask, logit, painted_image = self.xmem.track(images[i], template_mask)
47
+ masks.append(mask)
48
+ logits.append(logit)
49
+ painted_images.append(painted_image)
50
+
51
+ else:
52
+ mask, logit, painted_image = self.xmem.track(images[i])
53
+ masks.append(mask)
54
+ logits.append(logit)
55
+ painted_images.append(painted_image)
56
+ return masks, logits, painted_images
57
+
58
+
59
+ def parse_augment():
60
+ parser = argparse.ArgumentParser()
61
+ parser.add_argument('--device', type=str, default="cuda:0")
62
+ parser.add_argument('--sam_model_type', type=str, default="vit_h")
63
+ parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications")
64
+ parser.add_argument('--debug', action="store_true")
65
+ parser.add_argument('--mask_save', default=True)
66
+ args = parser.parse_args()
67
+
68
+ if args.debug:
69
+ print(args)
70
+ return args
71
+
72
+
73
+ if __name__ == "__main__":
74
+ masks = None
75
+ logits = None
76
+ painted_images = None
77
+ images = []
78
+ image = np.array(PIL.Image.open('/hhd3/gaoshang/truck.jpg'))
79
+ args = parse_augment()
80
+ # images.append(np.ones((20,20,3)).astype('uint8'))
81
+ # images.append(np.ones((20,20,3)).astype('uint8'))
82
+ images.append(image)
83
+ images.append(image)
84
+
85
+ mask = np.zeros_like(image)[:,:,0]
86
+ mask[0,0]= 1
87
+ trackany = TrackingAnything('/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth','/ssd1/gaomingqi/checkpoints/XMem-s012.pth', args)
88
+ masks, logits, painted_images = trackany.generator(images, mask)
89
+
90
+
91
+
92
+
93
+
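A sketch of the intended end-to-end flow of TrackingAnything: click once on the first frame to obtain a mask, then let generator propagate it through the rest of the clip. The checkpoint and video paths are placeholders, and the snippet assumes the repository root is on sys.path and a CUDA device is available:

import cv2
import numpy as np
from track_anything import TrackingAnything, parse_augment

args = parse_augment()
model = TrackingAnything('./sam_vit_h_4b8939.pth', './XMem-s012.pth', args)  # placeholder checkpoints

# read a short clip into a list of RGB frames
cap = cv2.VideoCapture('./assets/demo_version_1.MP4')                        # placeholder video
frames = []
while True:
    ok, frame = cap.read()
    if not ok:
        break
    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()

# one positive click on the first frame
points = np.array([[500, 375]])
labels = np.array([1])
mask, logit, painted = model.first_frame_click(frames[0], points, labels, multimask=True)

# propagate the first-frame mask through the whole clip
masks, logits, painted_images = model.generator(frames, mask.astype(np.uint8))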
tracker/.DS_Store ADDED
Binary file (6.15 kB).
tracker/base_tracker.py ADDED
@@ -0,0 +1,233 @@
1
+ # import for debugging
2
+ import os
3
+ import glob
4
+ import numpy as np
5
+ from PIL import Image
6
+ # import for base_tracker
7
+ import torch
8
+ import yaml
9
+ import torch.nn.functional as F
10
+ from model.network import XMem
11
+ from inference.inference_core import InferenceCore
12
+ from util.mask_mapper import MaskMapper
13
+ from torchvision import transforms
14
+ from util.range_transform import im_normalization
15
+ import sys
16
+ sys.path.insert(0, sys.path[0]+"/../")
17
+ from tools.painter import mask_painter
18
+ from tools.base_segmenter import BaseSegmenter
19
+ from torchvision.transforms import Resize
20
+
21
+
22
+ class BaseTracker:
23
+ def __init__(self, xmem_checkpoint, device, sam_model=None, model_type=None) -> None:
24
+ """
25
+ device: model device
26
+ xmem_checkpoint: checkpoint of XMem model
27
+ """
28
+ # load configurations
29
+ with open("tracker/config/config.yaml", 'r') as stream:
30
+ config = yaml.safe_load(stream)
31
+ # initialise XMem
32
+ network = XMem(config, xmem_checkpoint).to(device).eval()
33
+ # initialise InferenceCore
34
+ self.tracker = InferenceCore(network, config)
35
+ # data transformation
36
+ self.im_transform = transforms.Compose([
37
+ transforms.ToTensor(),
38
+ im_normalization,
39
+ ])
40
+ self.device = device
41
+
42
+ # changable properties
43
+ self.mapper = MaskMapper()
44
+ self.initialised = False
45
+
46
+ # # SAM-based refinement
47
+ # self.sam_model = sam_model
48
+ # self.resizer = Resize([256, 256])
49
+
50
+ @torch.no_grad()
51
+ def resize_mask(self, mask):
52
+ # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
53
+ h, w = mask.shape[-2:]
54
+ min_hw = min(h, w)
55
+ return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
56
+ mode='nearest')
57
+
58
+ @torch.no_grad()
59
+ def track(self, frame, first_frame_annotation=None):
60
+ """
61
+ Input:
62
+ frame: numpy array (H, W, 3)
64
+ first_frame_annotation: numpy array (H, W), label mask for the first frame (optional)
64
+
65
+ Output:
66
+ mask: numpy arrays (H, W)
67
+ logit: numpy arrays, probability map (H, W)
68
+ painted_image: numpy array (H, W, 3)
69
+ """
70
+ if first_frame_annotation is not None: # first frame mask
71
+ # initialisation
72
+ mask, labels = self.mapper.convert_mask(first_frame_annotation)
73
+ mask = torch.Tensor(mask).to(self.device)
74
+ self.tracker.set_all_labels(list(self.mapper.remappings.values()))
75
+ else:
76
+ mask = None
77
+ labels = None
78
+ # prepare inputs
79
+ frame_tensor = self.im_transform(frame).to(self.device)
80
+ # track one frame
81
+ probs, _ = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
82
+ # # refine
83
+ # if first_frame_annotation is None:
84
+ # out_mask = self.sam_refinement(frame, logits[1], ti)
85
+
86
+ # convert to mask
87
+ out_mask = torch.argmax(probs, dim=0)
88
+ out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
89
+
90
+ num_objs = out_mask.max()
91
+ painted_image = frame
92
+ for obj in range(1, num_objs+1):
93
+ painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
94
+
95
+ return out_mask, out_mask, painted_image
96
+
97
+ @torch.no_grad()
98
+ def sam_refinement(self, frame, logits, ti):
99
+ """
100
+ refine segmentation results with mask prompt
101
+ """
102
+ # convert to 1, 256, 256
103
+ self.sam_model.set_image(frame)
104
+ mode = 'mask'
105
+ logits = logits.unsqueeze(0)
106
+ logits = self.resizer(logits).cpu().numpy()
107
+ prompts = {'mask_input': logits} # 1 256 256
108
+ masks, scores, logits = self.sam_model.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
109
+ painted_image = mask_painter(frame, masks[np.argmax(scores)].astype('uint8'), mask_alpha=0.8)
110
+ painted_image = Image.fromarray(painted_image)
111
+ painted_image.save(f'/ssd1/gaomingqi/refine/{ti:05d}.png')
112
+ self.sam_model.reset_image()
113
+
114
+ @torch.no_grad()
115
+ def clear_memory(self):
116
+ self.tracker.clear_memory()
117
+ self.mapper.clear_labels()
118
+
119
+
120
+ if __name__ == '__main__':
121
+ # video frames (multiple objects)
122
+ video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
123
+ video_path_list.sort()
124
+ # first frame
125
+ first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
126
+ # load frames
127
+ frames = []
128
+ for video_path in video_path_list:
129
+ frames.append(np.array(Image.open(video_path).convert('RGB')))
130
+ frames = np.stack(frames, 0) # N, H, W, C
131
+ # load first frame annotation
132
+ first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
133
+
134
+ # ----------------------------------------------------------
135
+ # initalise tracker
136
+ # ----------------------------------------------------------
137
+ device = 'cuda:4'
138
+ XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
139
+ SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
140
+ model_type = 'vit_h'
141
+
142
+ # sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
143
+ tracker = BaseTracker(XMEM_checkpoint, device, None, model_type)
144
+
145
+ # test for storage efficiency
146
+ frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
147
+ first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
148
+
149
+ for ti, frame in enumerate(frames):
150
+ print(ti)
151
+ if ti > 200:
152
+ break
153
+ if ti == 0:
154
+ mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
155
+ else:
156
+ mask, prob, painted_image = tracker.track(frame)
157
+ # save
158
+ painted_image = Image.fromarray(painted_image)
159
+ painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
160
+
161
+ tracker.clear_memory()
162
+ for ti, frame in enumerate(frames):
163
+ print(ti)
164
+ # if ti > 200:
165
+ # break
166
+ if ti == 0:
167
+ mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
168
+ else:
169
+ mask, prob, painted_image = tracker.track(frame)
170
+ # save
171
+ painted_image = Image.fromarray(painted_image)
172
+ painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
173
+
174
+ # # track anything given in the first frame annotation
175
+ # for ti, frame in enumerate(frames):
176
+ # if ti == 0:
177
+ # mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
178
+ # else:
179
+ # mask, prob, painted_image = tracker.track(frame)
180
+ # # save
181
+ # painted_image = Image.fromarray(painted_image)
182
+ # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
183
+
184
+ # # ----------------------------------------------------------
185
+ # # another video
186
+ # # ----------------------------------------------------------
187
+ # # video frames
188
+ # video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/camel', '*.jpg'))
189
+ # video_path_list.sort()
190
+ # # first frame
191
+ # first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/camel/00000.png'
192
+ # # load frames
193
+ # frames = []
194
+ # for video_path in video_path_list:
195
+ # frames.append(np.array(Image.open(video_path).convert('RGB')))
196
+ # frames = np.stack(frames, 0) # N, H, W, C
197
+ # # load first frame annotation
198
+ # first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
199
+
200
+ # print('first video done. clear.')
201
+
202
+ # tracker.clear_memory()
203
+ # # track anything given in the first frame annotation
204
+ # for ti, frame in enumerate(frames):
205
+ # if ti == 0:
206
+ # mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
207
+ # else:
208
+ # mask, prob, painted_image = tracker.track(frame)
209
+ # # save
210
+ # painted_image = Image.fromarray(painted_image)
211
+ # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png')
212
+
213
+ # # failure case test
214
+ # failure_path = '/ssd1/gaomingqi/failure'
215
+ # frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
216
+ # # first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB'))
217
+ # first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P'))
218
+ # first_mask = np.clip(first_mask, 0, 1)
219
+
220
+ # for ti, frame in enumerate(frames):
221
+ # if ti == 0:
222
+ # mask, probs, painted_image = tracker.track(frame, first_mask)
223
+ # else:
224
+ # mask, probs, painted_image = tracker.track(frame)
225
+ # # save
226
+ # painted_image = Image.fromarray(painted_image)
227
+ # painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png')
228
+ # prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
229
+
230
+ # # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
231
+
232
+
233
+
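The tracking protocol in base_tracker.py is simple: pass the label mask only with the first frame, call track(frame) alone for every later frame, and call clear_memory() before starting on a new video. A compact sketch with placeholder paths (it assumes the repository root is the working directory; the tracker directory is added to sys.path explicitly, as track_anything.py does):

import sys
sys.path.append('./tracker')        # XMem modules are imported relative to this directory
import glob
import numpy as np
from PIL import Image
from tracker.base_tracker import BaseTracker

tracker = BaseTracker('./XMem-s012.pth', device='cuda:0')                   # placeholder checkpoint

frame_paths = sorted(glob.glob('./frames/*.jpg'))                           # placeholder frame folder
first_mask = np.array(Image.open('./frames/00000_mask.png').convert('P'))   # placeholder annotation

for ti, path in enumerate(frame_paths):
    frame = np.array(Image.open(path).convert('RGB'))
    if ti == 0:
        mask, prob, painted = tracker.track(frame, first_mask)              # initialise with the mask
    else:
        mask, prob, painted = tracker.track(frame)                          # propagate
    Image.fromarray(painted).save(f'./out/{ti:05d}.png')

tracker.clear_memory()                                                      # reset before the next clip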