oguzakif commited on
Commit
a1d02b9
1 Parent(s): 0c3c286

Init object remover

Browse files
Files changed (4) hide show
  1. FGT_codes +1 -0
  2. SiamMask +1 -0
  3. app.py +259 -0
  4. requirements.txt +20 -0
FGT_codes ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit f7bc8a2c520ef862d0a4d28a38334604eb41410c
SiamMask ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 0eaac33050fdcda81c9a25aa307fffa74c182e36
app.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import gradio as gr
3
+ from video_inpainting import video_inpainting
4
+ from tools.test import *
5
+ from custom import Custom
6
+ from types import SimpleNamespace
7
+ import torch
8
+ import numpy as np
9
+ import torchvision
10
+ import cv2
11
+ import sys
12
+ from os.path import exists, join, basename, splitext
13
+ import os
14
+ project_name = './video-object-remover'
15
+
16
+ sys.path.append(project_name)
17
+
18
+ sys.path.append(join(project_name, 'SiamMask',
19
+ 'experiments', 'siammask_sharp'))
20
+ sys.path.append(join(project_name, 'SiamMask', 'models'))
21
+ sys.path.append(join(project_name, 'SiamMask'))
22
+
23
+ exp_path = join(project_name, 'SiamMask/experiments/siammask_sharp')
24
+ pretrained_path1 = join(exp_path, 'SiamMask_DAVIS.pth')
25
+
26
+ sys.path.append(join(project_name, 'FGT_codes'))
27
+ sys.path.append(join(project_name, 'FGT_codes', 'tool'))
28
+ sys.path.append(join(project_name, 'FGT_codes', 'LAFC', 'flowCheckPoint'))
29
+ sys.path.append(join(project_name, 'FGT_codes', 'LAFC', 'checkpoint'))
30
+ sys.path.append(join(project_name, 'FGT_codes', 'FGT', 'checkpoint'))
31
+ sys.path.append(join(project_name, 'FGT_codes', 'LAFC',
32
+ 'flowCheckPoint', 'raft-things.pth'))
33
+
34
+ torch.set_grad_enabled(False)
35
+
36
+ # init SiamMask
37
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
38
+ cfg = load_config(SimpleNamespace(config=join(exp_path, 'config_davis.json')))
39
+ siammask = Custom(anchors=cfg['anchors'])
40
+ siammask = load_pretrain(siammask, pretrained_path1)
41
+ siammask = siammask.eval().to(device)
42
+
43
+ # constants
44
+ object_x = 0
45
+ object_y = 0
46
+ object_width = 0
47
+ object_height = 0
48
+ original_frame_list = []
49
+ mask_list = []
50
+
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument('--opt', default='configs/object_removal.yaml',
53
+ help='Please select your config file for inference')
54
+ # video completion
55
+ parser.add_argument('--mode', default='object_removal', choices=[
56
+ 'object_removal', 'watermark_removal', 'video_extrapolation'], help="modes: object_removal / video_extrapolation")
57
+ parser.add_argument(
58
+ '--path', default='/myData/davis_resized/walking', help="dataset for evaluation")
59
+ parser.add_argument(
60
+ '--path_mask', default='/myData/dilateAnnotations_4/walking', help="mask for object removal")
61
+ parser.add_argument(
62
+ '--outroot', default='quick_start/walking3', help="output directory")
63
+ parser.add_argument('--consistencyThres', dest='consistencyThres', default=5, type=float,
64
+ help='flow consistency error threshold')
65
+ parser.add_argument('--alpha', dest='alpha', default=0.1, type=float)
66
+ parser.add_argument('--Nonlocal', dest='Nonlocal',
67
+ default=False, type=bool)
68
+
69
+ # RAFT
70
+ parser.add_argument(
71
+ '--raft_model', default='../LAFC/flowCheckPoint/raft-things.pth', help="restore checkpoint")
72
+ parser.add_argument('--small', action='store_true', help='use small model')
73
+ parser.add_argument('--mixed_precision',
74
+ action='store_true', help='use mixed precision')
75
+ parser.add_argument('--alternate_corr', action='store_true',
76
+ help='use efficent correlation implementation')
77
+
78
+ # LAFC
79
+ parser.add_argument('--lafc_ckpts', type=str, default='../LAFC/checkpoint')
80
+
81
+ # FGT
82
+ parser.add_argument('--fgt_ckpts', type=str, default='../FGT/checkpoint')
83
+
84
+ # extrapolation
85
+ parser.add_argument('--H_scale', dest='H_scale', default=2,
86
+ type=float, help='H extrapolation scale')
87
+ parser.add_argument('--W_scale', dest='W_scale', default=2,
88
+ type=float, help='W extrapolation scale')
89
+
90
+ # Image basic information
91
+ parser.add_argument('--imgH', type=int, default=256)
92
+ parser.add_argument('--imgW', type=int, default=432)
93
+ parser.add_argument('--flow_mask_dilates', type=int, default=8)
94
+ parser.add_argument('--frame_dilates', type=int, default=0)
95
+
96
+ parser.add_argument('--gpu', type=int, default=0)
97
+
98
+ # FGT inference parameters
99
+ parser.add_argument('--step', type=int, default=10)
100
+ parser.add_argument('--num_ref', type=int, default=-1)
101
+ parser.add_argument('--neighbor_stride', type=int, default=5)
102
+
103
+ # visualization
104
+ parser.add_argument('--vis_flows', action='store_true',
105
+ help='Visualize the initialized flows')
106
+ parser.add_argument('--vis_completed_flows',
107
+ action='store_true', help='Visualize the completed flows')
108
+ parser.add_argument('--vis_prop', action='store_true',
109
+ help='Visualize the frames after stage-I filling (flow guided content propagation)')
110
+ parser.add_argument('--vis_frame', action='store_true',
111
+ help='Visualize frames')
112
+
113
+ args = parser.parse_args()
114
+
115
+
116
+ def getBoundaries(mask):
117
+ if mask is None:
118
+ return 0, 0, 0, 0
119
+
120
+ indexes = np.where((mask == [255, 255, 255]).all(axis=2))
121
+ print(indexes)
122
+ x1 = min(indexes[1])
123
+ y1 = min(indexes[0])
124
+ x2 = max(indexes[1])
125
+ y2 = max(indexes[0])
126
+
127
+ return x1, y1, (x2-x1), (y2-y1)
128
+
129
+
130
+ def track_and_mask(vid, original_frame, masked_frame):
131
+ x, y, w, h = getBoundaries(masked_frame)
132
+ f = 0
133
+
134
+ video_capture = cv2.VideoCapture()
135
+ if video_capture.open(vid):
136
+ width, height = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(
137
+ video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
138
+ fps = video_capture.get(cv2.CAP_PROP_FPS)
139
+
140
+ # can't write out mp4, so try to write into an AVI file
141
+ video_writer = cv2.VideoWriter(
142
+ "output.avi", cv2.VideoWriter_fourcc(*'MP42'), fps, (width, height))
143
+ video_writer2 = cv2.VideoWriter(
144
+ "output_mask.avi", cv2.VideoWriter_fourcc(*'MP42'), fps, (width, height))
145
+
146
+ while video_capture.isOpened():
147
+ ret, frame = video_capture.read()
148
+
149
+ if not ret:
150
+ break
151
+
152
+ # frame = cv2.resize(frame, (w - w % 8, h - h % 8))
153
+ if f == 0:
154
+ target_pos = np.array([x + w / 2, y + h / 2])
155
+ target_sz = np.array([w, h])
156
+ # init tracker
157
+ state = siamese_init(
158
+ frame, target_pos, target_sz, siammask, cfg['hp'], device=device)
159
+ else:
160
+ # track
161
+ state = siamese_track(
162
+ state, frame, mask_enable=True, refine_enable=True, device=device)
163
+ location = state['ploygon'].flatten()
164
+ mask = state['mask'] > state['p'].seg_thr
165
+ frame[:, :, 2] = (mask > 0) * 255 + \
166
+ (mask == 0) * frame[:, :, 2]
167
+
168
+ mask = mask.astype(np.uint8) # convert to an unsigned byte
169
+ mask = mask * 255
170
+ mask_list.append(mask)
171
+ cv2.polylines(frame, [np.int0(location).reshape(
172
+ (-1, 1, 2))], True, (0, 255, 0), 3)
173
+
174
+ original_frame_list.append(frame)
175
+ mask_list.append(mask)
176
+
177
+ video_writer.write(frame)
178
+ video_writer2.write(mask)
179
+ f = f + 1
180
+
181
+ video_capture.release()
182
+ video_writer.release()
183
+ video_writer2.release()
184
+
185
+ else:
186
+ print("can't open the given input video file!")
187
+
188
+ return "output.mp4"
189
+
190
+
191
+ def inpaint_video():
192
+ video_inpainting(args, original_frame_list, mask_list)
193
+
194
+ return "output.mp4"
195
+
196
+
197
+ def get_first_frame(video):
198
+ video_capture = cv2.VideoCapture()
199
+ if video_capture.open(video):
200
+ width, height = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(
201
+ video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
202
+
203
+ if video_capture.isOpened():
204
+ ret, frame = video_capture.read()
205
+
206
+ RGB_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
207
+ return RGB_frame
208
+
209
+
210
+ def drawRectangle(frame, mask):
211
+ x1, y1, x2, y2 = getBoundaries(mask)
212
+
213
+ return cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
214
+
215
+
216
+ def getStartEndPoints(mask):
217
+ if mask is None:
218
+ return 0, 0, 0, 0
219
+
220
+ indexes = np.where((mask == [255, 255, 255]).all(axis=2))
221
+ print(indexes)
222
+ x1 = min(indexes[1])
223
+ y1 = min(indexes[0])
224
+ x2 = max(indexes[1])
225
+ y2 = max(indexes[0])
226
+
227
+ return x1, y1, x2, y2
228
+
229
+
230
+ with gr.Blocks() as demo:
231
+ with gr.Row():
232
+ with gr.Column(scale=2):
233
+ with gr.Row():
234
+ in_video = gr.Video()
235
+ with gr.Row():
236
+ first_frame = gr.ImageMask()
237
+ with gr.Row():
238
+ approve_mask = gr.Button(value="Approve Mask")
239
+ with gr.Column(scale=1):
240
+ with gr.Row():
241
+ original_image = gr.Image(interactive=False)
242
+ with gr.Row():
243
+ masked_image = gr.Image(interactive=False)
244
+ with gr.Column(scale=2):
245
+ out_video = gr.Video()
246
+ out_video_inpaint = gr.Video()
247
+ track_mask = gr.Button(value="Track and Mask")
248
+ inpaint = gr.Button(value="Inpaint")
249
+
250
+ in_video.change(fn=get_first_frame, inputs=[
251
+ in_video], outputs=[first_frame])
252
+ approve_mask.click(lambda x: [x['image'], x['mask']], first_frame, [
253
+ original_image, masked_image])
254
+ track_mask.click(fn=track_and_mask, inputs=[
255
+ in_video, original_image, masked_image], outputs=[out_video])
256
+ inpaint.click(fn=inpaint_video, outputs=[out_video_inpaint])
257
+
258
+
259
+ demo.launch(share=True, debug=True)
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.10.1
2
+ torchvision==0.11.2
3
+ cvbase==0.5.5
4
+ imageio==2.6.1
5
+ matplotlib==3.1.1
6
+ numpy==1.22.2
7
+ opencv-python
8
+ Pillow
9
+ PyYAML
10
+ scikit-image
11
+ scipy
12
+ tensorboardX
13
+ imageio-ffmpeg
14
+ Cython==0.29.34
15
+ colorama==0.3.9
16
+ requests==2.21.0
17
+ fire==0.1.3
18
+ numba==0.39.0
19
+ h5py==2.8.0
20
+ tqdm==4.29.1