watchtowerss commited on
Commit
05187ec
1 Parent(s): c2afc01

huggingface -- version 2

Browse files
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  assets/demo_version_1.MP4 filter=lfs diff=lfs merge=lfs -text
36
  assets/inpainting.gif filter=lfs diff=lfs merge=lfs -text
37
  assets/qingming.mp4 filter=lfs diff=lfs merge=lfs -text
 
35
  assets/demo_version_1.MP4 filter=lfs diff=lfs merge=lfs -text
36
  assets/inpainting.gif filter=lfs diff=lfs merge=lfs -text
37
  assets/qingming.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ test_sample/test-sample1.mp4 filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -17,7 +17,7 @@ import torchvision
17
  import torch
18
  import concurrent.futures
19
  import queue
20
-
21
  # download checkpoints
22
  def download_checkpoint(url, folder, filename):
23
  os.makedirs(folder, exist_ok=True)
@@ -84,12 +84,21 @@ def get_frames_from_video(video_input, video_state):
84
  "masks": [None]*len(frames),
85
  "logits": [None]*len(frames),
86
  "select_frame_number": 0,
87
- "fps": 30
88
  }
89
- return video_state, gr.update(visible=True, maximum=len(frames), value=1)
 
 
 
 
 
 
 
 
 
90
 
91
  # get the select frame from gradio slider
92
- def select_template(image_selection_slider, video_state):
93
 
94
  # images = video_state[1]
95
  image_selection_slider -= 1
@@ -100,8 +109,14 @@ def select_template(image_selection_slider, video_state):
100
  model.samcontroler.sam_controler.reset_image()
101
  model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
102
 
 
 
103
 
104
- return video_state["painted_images"][image_selection_slider], video_state
 
 
 
 
105
 
106
  # use sam to get the mask
107
  def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
@@ -133,17 +148,65 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
133
 
134
  return painted_image, video_state, interactive_state
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  # tracking vos
137
- def vos_tracking_video(video_state, interactive_state):
138
  model.xmem.clear_memory()
139
- following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
140
- template_mask = video_state["masks"][video_state["select_frame_number"]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  fps = video_state["fps"]
142
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
143
 
144
- video_state["masks"][video_state["select_frame_number"]:] = masks
145
- video_state["logits"][video_state["select_frame_number"]:] = logits
146
- video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
 
 
 
 
 
147
 
148
  video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
149
  interactive_state["inference_times"] += 1
@@ -152,7 +215,7 @@ def vos_tracking_video(video_state, interactive_state):
152
  interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
153
  interactive_state["positive_click_times"],
154
  interactive_state["negative_click_times"]))
155
-
156
  #### shanggao code for mask save
157
  if interactive_state["mask_save"]:
158
  if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
@@ -176,6 +239,14 @@ def generate_video_from_frames(frames, output_path, fps=30):
176
  output_path (str): The path to save the generated video.
177
  fps (int, optional): The frame rate of the output video. Defaults to 30.
178
  """
 
 
 
 
 
 
 
 
179
  frames = torch.from_numpy(np.asarray(frames))
180
  if not os.path.exists(os.path.dirname(output_path)):
181
  os.makedirs(os.path.dirname(output_path))
@@ -193,8 +264,8 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
193
 
194
  # args, defined in track_anything.py
195
  args = parse_augment()
196
- # args.port = 12212
197
- # args.device = "cuda:4"
198
  # args.mask_save = True
199
 
200
  model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
@@ -208,8 +279,15 @@ with gr.Blocks() as iface:
208
  "inference_times": 0,
209
  "negative_click_times" : 0,
210
  "positive_click_times": 0,
211
- "mask_save": args.mask_save
212
- })
 
 
 
 
 
 
 
213
  video_state = gr.State(
214
  {
215
  "video_name": "",
@@ -225,43 +303,47 @@ with gr.Blocks() as iface:
225
  with gr.Row():
226
 
227
  # for user video input
228
- with gr.Column(scale=1.0):
229
- video_input = gr.Video().style(height=360)
 
 
230
 
231
 
232
 
233
- with gr.Row(scale=1):
234
  # put the template frame under the radio button
235
- with gr.Column(scale=0.5):
236
  # extract frames
237
  with gr.Column():
238
  extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
239
 
240
  # click points settins, negative or positive, mode continuous or single
241
  with gr.Row():
242
- with gr.Row(scale=0.5):
243
  point_prompt = gr.Radio(
244
  choices=["Positive", "Negative"],
245
  value="Positive",
246
  label="Point Prompt",
247
- interactive=True)
 
248
  click_mode = gr.Radio(
249
  choices=["Continuous", "Single"],
250
  value="Continuous",
251
  label="Clicking Mode",
252
- interactive=True)
253
- with gr.Row(scale=0.5):
254
- clear_button_clike = gr.Button(value="Clear Clicks", interactive=True).style(height=160)
255
- clear_button_image = gr.Button(value="Clear Image", interactive=True)
256
- template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360)
257
- image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", invisible=False)
258
-
259
-
260
-
261
 
262
- with gr.Column(scale=0.5):
263
- video_output = gr.Video().style(height=360)
264
- tracking_video_predict_button = gr.Button(value="Tracking")
 
 
265
 
266
  # first step: get the video information
267
  extract_frames_button.click(
@@ -269,27 +351,52 @@ with gr.Blocks() as iface:
269
  inputs=[
270
  video_input, video_state
271
  ],
272
- outputs=[video_state, image_selection_slider],
 
 
273
  )
274
 
275
  # second step: select images from slider
276
  image_selection_slider.release(fn=select_template,
277
- inputs=[image_selection_slider, video_state],
278
- outputs=[template_frame, video_state], api_name="select_image")
 
 
 
279
 
280
-
281
  template_frame.select(
282
  fn=sam_refine,
283
  inputs=[video_state, point_prompt, click_state, interactive_state],
284
  outputs=[template_frame, video_state, interactive_state]
285
  )
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  tracking_video_predict_button.click(
288
  fn=vos_tracking_video,
289
- inputs=[video_state, interactive_state],
290
  outputs=[video_output, video_state, interactive_state]
291
  )
292
 
 
 
 
 
 
 
293
 
294
  # clear input
295
  video_input.clear(
@@ -306,57 +413,43 @@ with gr.Blocks() as iface:
306
  "inference_times": 0,
307
  "negative_click_times" : 0,
308
  "positive_click_times": 0,
309
- "mask_save": args.mask_save
310
- },
311
- [[],[]]
312
- ),
313
- [],
314
- [
315
- video_state,
316
- interactive_state,
317
- click_state,
318
- ],
319
- queue=False,
320
- show_progress=False
321
- )
322
- clear_button_image.click(
323
- lambda: (
324
- {
325
- "origin_images": None,
326
- "painted_images": None,
327
- "masks": None,
328
- "logits": None,
329
- "select_frame_number": 0,
330
- "fps": 30
331
  },
332
- {
333
- "inference_times": 0,
334
- "negative_click_times" : 0,
335
- "positive_click_times": 0,
336
- "mask_save": args.mask_save
337
  },
338
- [[],[]]
339
- ),
 
 
 
 
 
 
340
  [],
341
  [
342
  video_state,
343
  interactive_state,
344
  click_state,
 
 
 
 
345
  ],
346
-
347
  queue=False,
348
- show_progress=False
 
 
 
 
 
 
349
 
350
- )
351
- clear_button_clike.click(
352
- lambda: ([[],[]]),
353
- [],
354
- [click_state],
355
- queue=False,
356
- show_progress=False
357
  )
358
  iface.queue(concurrency_count=1)
359
- iface.launch(enable_queue=True)
360
 
361
 
362
 
17
  import torch
18
  import concurrent.futures
19
  import queue
20
+ from tools.painter import mask_painter, point_painter
21
  # download checkpoints
22
  def download_checkpoint(url, folder, filename):
23
  os.makedirs(folder, exist_ok=True)
84
  "masks": [None]*len(frames),
85
  "logits": [None]*len(frames),
86
  "select_frame_number": 0,
87
+ "fps": fps
88
  }
89
+ video_info = "Video Name: {}, FPS: {}, Total Frames: {}".format(video_state["video_name"], video_state["fps"], len(frames))
90
+
91
+ model.samcontroler.sam_controler.reset_image()
92
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
93
+ return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
94
+ gr.update(visible=True), gr.update(visible=True), \
95
+ gr.update(visible=True), gr.update(visible=True), \
96
+ gr.update(visible=True), gr.update(visible=True), \
97
+ gr.update(visible=True), gr.update(visible=True), \
98
+ gr.update(visible=True)
99
 
100
  # get the select frame from gradio slider
101
+ def select_template(image_selection_slider, video_state, interactive_state):
102
 
103
  # images = video_state[1]
104
  image_selection_slider -= 1
109
  model.samcontroler.sam_controler.reset_image()
110
  model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
111
 
112
+ # # clear multi mask
113
+ # interactive_state["multi_mask"] = {"masks":[], "mask_names":[]}
114
 
115
+ return video_state["painted_images"][image_selection_slider], video_state, interactive_state
116
+
117
+ def get_end_number(track_pause_number_slider, interactive_state):
118
+ interactive_state["track_end_number"] = track_pause_number_slider
119
+ return interactive_state
120
 
121
  # use sam to get the mask
122
  def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
148
 
149
  return painted_image, video_state, interactive_state
150
 
151
+ def add_multi_mask(video_state, interactive_state, mask_dropdown):
152
+ mask = video_state["masks"][video_state["select_frame_number"]]
153
+ interactive_state["multi_mask"]["masks"].append(mask)
154
+ interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
155
+ mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
156
+ select_frame = show_mask(video_state, interactive_state, mask_dropdown)
157
+ return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
158
+
159
+ def clear_click(video_state, click_state):
160
+ click_state = [[],[]]
161
+ template_frame = video_state["origin_images"][video_state["select_frame_number"]]
162
+ return template_frame, click_state
163
+
164
+ def remove_multi_mask(interactive_state):
165
+ interactive_state["multi_mask"]["mask_names"]= []
166
+ interactive_state["multi_mask"]["masks"] = []
167
+ return interactive_state
168
+
169
+ def show_mask(video_state, interactive_state, mask_dropdown):
170
+ mask_dropdown.sort()
171
+ select_frame = video_state["origin_images"][video_state["select_frame_number"]]
172
+
173
+ for i in range(len(mask_dropdown)):
174
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
175
+ mask = interactive_state["multi_mask"]["masks"][mask_number]
176
+ select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
177
+
178
+ return select_frame
179
+
180
  # tracking vos
181
+ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
182
  model.xmem.clear_memory()
183
+ if interactive_state["track_end_number"]:
184
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
185
+ else:
186
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
187
+
188
+ if interactive_state["multi_mask"]["masks"]:
189
+ if len(mask_dropdown) == 0:
190
+ mask_dropdown = ["mask_001"]
191
+ mask_dropdown.sort()
192
+ template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
193
+ for i in range(1,len(mask_dropdown)):
194
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
195
+ template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
196
+ video_state["masks"][video_state["select_frame_number"]]= template_mask
197
+ else:
198
+ template_mask = video_state["masks"][video_state["select_frame_number"]]
199
  fps = video_state["fps"]
200
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
201
 
202
+ if interactive_state["track_end_number"]:
203
+ video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
204
+ video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
205
+ video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
206
+ else:
207
+ video_state["masks"][video_state["select_frame_number"]:] = masks
208
+ video_state["logits"][video_state["select_frame_number"]:] = logits
209
+ video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
210
 
211
  video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
212
  interactive_state["inference_times"] += 1
215
  interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
216
  interactive_state["positive_click_times"],
217
  interactive_state["negative_click_times"]))
218
+
219
  #### shanggao code for mask save
220
  if interactive_state["mask_save"]:
221
  if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
239
  output_path (str): The path to save the generated video.
240
  fps (int, optional): The frame rate of the output video. Defaults to 30.
241
  """
242
+ # height, width, layers = frames[0].shape
243
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
244
+ # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
245
+ # print(output_path)
246
+ # for frame in frames:
247
+ # video.write(frame)
248
+
249
+ # video.release()
250
  frames = torch.from_numpy(np.asarray(frames))
251
  if not os.path.exists(os.path.dirname(output_path)):
252
  os.makedirs(os.path.dirname(output_path))
264
 
265
  # args, defined in track_anything.py
266
  args = parse_augment()
267
+ # args.port = 12315
268
+ # args.device = "cuda:1"
269
  # args.mask_save = True
270
 
271
  model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
279
  "inference_times": 0,
280
  "negative_click_times" : 0,
281
  "positive_click_times": 0,
282
+ "mask_save": args.mask_save,
283
+ "multi_mask": {
284
+ "mask_names": [],
285
+ "masks": []
286
+ },
287
+ "track_end_number": None
288
+ }
289
+ )
290
+
291
  video_state = gr.State(
292
  {
293
  "video_name": "",
303
  with gr.Row():
304
 
305
  # for user video input
306
+ with gr.Column():
307
+ with gr.Row(scale=0.4):
308
+ video_input = gr.Video(autosize=True)
309
+ video_info = gr.Textbox()
310
 
311
 
312
 
313
+ with gr.Row():
314
  # put the template frame under the radio button
315
+ with gr.Column():
316
  # extract frames
317
  with gr.Column():
318
  extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
319
 
320
  # click points settins, negative or positive, mode continuous or single
321
  with gr.Row():
322
+ with gr.Row():
323
  point_prompt = gr.Radio(
324
  choices=["Positive", "Negative"],
325
  value="Positive",
326
  label="Point Prompt",
327
+ interactive=True,
328
+ visible=False)
329
  click_mode = gr.Radio(
330
  choices=["Continuous", "Single"],
331
  value="Continuous",
332
  label="Clicking Mode",
333
+ interactive=True,
334
+ visible=False)
335
+ with gr.Row():
336
+ clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
337
+ Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
338
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
339
+ image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", visible=False)
340
+ track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
 
341
 
342
+ with gr.Column():
343
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask_select", info=".", visible=False)
344
+ remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
345
+ video_output = gr.Video(autosize=True, visible=False).style(height=360)
346
+ tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
347
 
348
  # first step: get the video information
349
  extract_frames_button.click(
351
  inputs=[
352
  video_input, video_state
353
  ],
354
+ outputs=[video_state, video_info, template_frame,
355
+ image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
356
+ tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button]
357
  )
358
 
359
  # second step: select images from slider
360
  image_selection_slider.release(fn=select_template,
361
+ inputs=[image_selection_slider, video_state, interactive_state],
362
+ outputs=[template_frame, video_state, interactive_state], api_name="select_image")
363
+ track_pause_number_slider.release(fn=get_end_number,
364
+ inputs=[track_pause_number_slider, interactive_state],
365
+ outputs=[interactive_state], api_name="end_image")
366
 
367
+ # click select image to get mask using sam
368
  template_frame.select(
369
  fn=sam_refine,
370
  inputs=[video_state, point_prompt, click_state, interactive_state],
371
  outputs=[template_frame, video_state, interactive_state]
372
  )
373
 
374
+ # add different mask
375
+ Add_mask_button.click(
376
+ fn=add_multi_mask,
377
+ inputs=[video_state, interactive_state, mask_dropdown],
378
+ outputs=[interactive_state, mask_dropdown, template_frame, click_state]
379
+ )
380
+
381
+ remove_mask_button.click(
382
+ fn=remove_multi_mask,
383
+ inputs=[interactive_state],
384
+ outputs=[interactive_state]
385
+ )
386
+
387
+ # tracking video from select image and mask
388
  tracking_video_predict_button.click(
389
  fn=vos_tracking_video,
390
+ inputs=[video_state, interactive_state, mask_dropdown],
391
  outputs=[video_output, video_state, interactive_state]
392
  )
393
 
394
+ # click to get mask
395
+ mask_dropdown.change(
396
+ fn=show_mask,
397
+ inputs=[video_state, interactive_state, mask_dropdown],
398
+ outputs=[template_frame]
399
+ )
400
 
401
  # clear input
402
  video_input.clear(
413
  "inference_times": 0,
414
  "negative_click_times" : 0,
415
  "positive_click_times": 0,
416
+ "mask_save": args.mask_save,
417
+ "multi_mask": {
418
+ "mask_names": [],
419
+ "masks": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  },
421
+ "track_end_number": 0
 
 
 
 
422
  },
423
+ [[],[]],
424
+ None,
425
+ None,
426
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
427
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
428
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False) \
429
+
430
+ ),
431
  [],
432
  [
433
  video_state,
434
  interactive_state,
435
  click_state,
436
+ video_output,
437
+ template_frame,
438
+ tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
439
+ Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button
440
  ],
 
441
  queue=False,
442
+ show_progress=False)
443
+
444
+ # points clear
445
+ clear_button_click.click(
446
+ fn = clear_click,
447
+ inputs = [video_state, click_state,],
448
+ outputs = [template_frame,click_state],
449
 
 
 
 
 
 
 
 
450
  )
451
  iface.queue(concurrency_count=1)
452
+ iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
453
 
454
 
455
 
app_test.py CHANGED
@@ -1,23 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def update_iframe(slider_value):
4
- return f'''
5
- <script>
6
- window.addEventListener('message', function(event) {{
7
- if (event.data.sliderValue !== undefined) {{
8
- var iframe = document.getElementById("text_iframe");
9
- iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
10
- }}
11
- }}, false);
12
- </script>
13
- <iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
14
- '''
15
-
16
- iface = gr.Interface(
17
- fn=update_iframe,
18
- inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
19
- outputs=gr.outputs.HTML(),
20
- allow_flagging=False,
21
- )
22
-
23
- iface.launch(server_name='0.0.0.0', server_port=12212)
1
+ # import gradio as gr
2
+
3
+ # def update_iframe(slider_value):
4
+ # return f'''
5
+ # <script>
6
+ # window.addEventListener('message', function(event) {{
7
+ # if (event.data.sliderValue !== undefined) {{
8
+ # var iframe = document.getElementById("text_iframe");
9
+ # iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
10
+ # }}
11
+ # }}, false);
12
+ # </script>
13
+ # <iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
14
+ # '''
15
+
16
+ # iface = gr.Interface(
17
+ # fn=update_iframe,
18
+ # inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
19
+ # outputs=gr.outputs.HTML(),
20
+ # allow_flagging=False,
21
+ # )
22
+
23
+ # iface.launch(server_name='0.0.0.0', server_port=12212)
24
+
25
  import gradio as gr
26
 
27
+
28
+ def change_mask(drop):
29
+ return gr.update(choices=["hello", "kitty"])
30
+
31
+ with gr.Blocks() as iface:
32
+ drop = gr.Dropdown(
33
+ choices=["cat", "dog", "bird"], label="Animal", info="Will add more animals later!"
34
+ )
35
+ radio = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")
36
+ multi_drop = gr.Dropdown(
37
+ ["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Activity", info="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed auctor, nisl eget ultricies aliquam, nunc nisl aliquet nunc, eget aliquam nisl nunc vel nisl."
38
+ )
39
+
40
+ multi_drop.change(
41
+ fn=change_mask,
42
+ inputs = multi_drop,
43
+ outputs=multi_drop
44
+ )
45
+
46
+ iface.launch(server_name='0.0.0.0', server_port=1223)
 
test.txt ADDED
File without changes
test_beta.txt ADDED
File without changes
test_sample/test-sample1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:403b711376a79026beedb7d0d919d35268298150120438a22a5330d0c8cdd6b6
3
+ size 6039223
tools/interact_tools.py CHANGED
@@ -37,16 +37,16 @@ class SamControler():
37
  self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
38
 
39
 
40
- def seg_again(self, image: np.ndarray):
41
- '''
42
- it is used when interact in video
43
- '''
44
- self.sam_controler.reset_image()
45
- self.sam_controler.set_image(image)
46
- return
47
 
48
 
49
- def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
50
  '''
51
  it is used in first frame in video
52
  return: mask, logit, painted image(mask+point)
@@ -88,47 +88,47 @@ class SamControler():
88
 
89
  return mask, logit, painted_image
90
 
91
- def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
92
- origal_image = self.sam_controler.orignal_image
93
- if same:
94
- '''
95
- true; loop in the same image
96
- '''
97
- prompts = {
98
- 'point_coords': points,
99
- 'point_labels': labels,
100
- 'mask_input': logits[None, :, :]
101
- }
102
- masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
103
- mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
104
 
105
- painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
106
- painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
107
- painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
108
- painted_image = Image.fromarray(painted_image)
109
 
110
- return mask, logit, painted_image
111
- else:
112
- '''
113
- loop in the different image, interact in the video
114
- '''
115
- if image is None:
116
- raise('Image error')
117
- else:
118
- self.seg_again(image)
119
- prompts = {
120
- 'point_coords': points,
121
- 'point_labels': labels,
122
- }
123
- masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
124
- mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
125
 
126
- painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
127
- painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
128
- painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
129
- painted_image = Image.fromarray(painted_image)
130
 
131
- return mask, logit, painted_image
132
 
133
 
134
 
@@ -226,31 +226,31 @@ class SamControler():
226
 
227
 
228
 
229
- if __name__ == "__main__":
230
- points = np.array([[500, 375], [1125, 625]])
231
- labels = np.array([1, 1])
232
- image = cv2.imread('/hhd3/gaoshang/truck.jpg')
233
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
234
 
235
- sam_controler = initialize()
236
- mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
237
- painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
238
- painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
239
- cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
240
- cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
241
- painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
242
 
243
- mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
244
- painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
245
- painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
246
- cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
247
- painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
248
 
249
- mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
250
- painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
251
- painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
252
- cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
253
- painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')
254
 
255
 
256
 
37
  self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
38
 
39
 
40
+ # def seg_again(self, image: np.ndarray):
41
+ # '''
42
+ # it is used when interact in video
43
+ # '''
44
+ # self.sam_controler.reset_image()
45
+ # self.sam_controler.set_image(image)
46
+ # return
47
 
48
 
49
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
50
  '''
51
  it is used in first frame in video
52
  return: mask, logit, painted image(mask+point)
88
 
89
  return mask, logit, painted_image
90
 
91
+ # def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
92
+ # origal_image = self.sam_controler.orignal_image
93
+ # if same:
94
+ # '''
95
+ # true; loop in the same image
96
+ # '''
97
+ # prompts = {
98
+ # 'point_coords': points,
99
+ # 'point_labels': labels,
100
+ # 'mask_input': logits[None, :, :]
101
+ # }
102
+ # masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
103
+ # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
104
 
105
+ # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
106
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
107
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
108
+ # painted_image = Image.fromarray(painted_image)
109
 
110
+ # return mask, logit, painted_image
111
+ # else:
112
+ # '''
113
+ # loop in the different image, interact in the video
114
+ # '''
115
+ # if image is None:
116
+ # raise('Image error')
117
+ # else:
118
+ # self.seg_again(image)
119
+ # prompts = {
120
+ # 'point_coords': points,
121
+ # 'point_labels': labels,
122
+ # }
123
+ # masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
124
+ # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
125
 
126
+ # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
127
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
128
+ # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
129
+ # painted_image = Image.fromarray(painted_image)
130
 
131
+ # return mask, logit, painted_image
132
 
133
 
134
 
226
 
227
 
228
 
229
+ # if __name__ == "__main__":
230
+ # points = np.array([[500, 375], [1125, 625]])
231
+ # labels = np.array([1, 1])
232
+ # image = cv2.imread('/hhd3/gaoshang/truck.jpg')
233
+ # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
234
 
235
+ # sam_controler = initialize()
236
+ # mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
237
+ # painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
238
+ # painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
239
+ # cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
240
+ # cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
241
+ # painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
242
 
243
+ # mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
244
+ # painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
245
+ # painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
246
+ # cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
247
+ # painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
248
 
249
+ # mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
250
+ # painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
251
+ # painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
252
+ # cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
253
+ # painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')
254
 
255
 
256
 
track_anything.py CHANGED
@@ -15,26 +15,26 @@ class TrackingAnything():
15
  self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
16
 
17
 
18
- def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
19
- same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
20
- if first_flag:
21
- mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
22
- return mask, logit, painted_image
23
 
24
- if interact_flag:
25
- mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
26
- return mask, logit, painted_image
27
 
28
- mask, logit, painted_image = self.xmem.track(image, logit)
29
- return mask, logit, painted_image
30
 
31
  def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
32
  mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
33
  return mask, logit, painted_image
34
 
35
- def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
36
- mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
37
- return mask, logit, painted_image
38
 
39
  def generator(self, images: list, template_mask:np.ndarray):
40
 
@@ -53,6 +53,7 @@ class TrackingAnything():
53
  masks.append(mask)
54
  logits.append(logit)
55
  painted_images.append(painted_image)
 
56
  return masks, logits, painted_images
57
 
58
 
15
  self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
16
 
17
 
18
+ # def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
19
+ # same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
20
+ # if first_flag:
21
+ # mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
22
+ # return mask, logit, painted_image
23
 
24
+ # if interact_flag:
25
+ # mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
26
+ # return mask, logit, painted_image
27
 
28
+ # mask, logit, painted_image = self.xmem.track(image, logit)
29
+ # return mask, logit, painted_image
30
 
31
  def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
32
  mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
33
  return mask, logit, painted_image
34
 
35
+ # def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
36
+ # mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
37
+ # return mask, logit, painted_image
38
 
39
  def generator(self, images: list, template_mask:np.ndarray):
40
 
53
  masks.append(mask)
54
  logits.append(logit)
55
  painted_images.append(painted_image)
56
+ print("tracking image {}".format(i))
57
  return masks, logits, painted_images
58
 
59
 
tracker/base_tracker.py CHANGED
@@ -67,6 +67,7 @@ class BaseTracker:
67
  logit: numpy arrays, probability map (H, W)
68
  painted_image: numpy array (H, W, 3)
69
  """
 
70
  if first_frame_annotation is not None: # first frame mask
71
  # initialisation
72
  mask, labels = self.mapper.convert_mask(first_frame_annotation)
@@ -87,12 +88,20 @@ class BaseTracker:
87
  out_mask = torch.argmax(probs, dim=0)
88
  out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
89
 
90
- num_objs = out_mask.max()
 
 
 
 
 
 
91
  painted_image = frame
92
  for obj in range(1, num_objs+1):
93
- painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
94
-
95
- return out_mask, out_mask, painted_image
 
 
96
 
97
  @torch.no_grad()
98
  def sam_refinement(self, frame, logits, ti):
@@ -142,34 +151,38 @@ if __name__ == '__main__':
142
  # sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
143
  tracker = BaseTracker(XMEM_checkpoint, device, None, device)
144
 
145
- # test for storage efficiency
146
- frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
147
- first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
148
 
149
- for ti, frame in enumerate(frames):
150
- print(ti)
151
- if ti > 200:
152
- break
153
- if ti == 0:
154
- mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
155
- else:
156
- mask, prob, painted_image = tracker.track(frame)
157
- # save
158
- painted_image = Image.fromarray(painted_image)
159
- painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
160
 
161
- tracker.clear_memory()
162
  for ti, frame in enumerate(frames):
163
- print(ti)
164
- # if ti > 200:
165
- # break
166
  if ti == 0:
167
  mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
168
  else:
169
  mask, prob, painted_image = tracker.track(frame)
170
  # save
171
  painted_image = Image.fromarray(painted_image)
172
- painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  # # track anything given in the first frame annotation
175
  # for ti, frame in enumerate(frames):
67
  logit: numpy arrays, probability map (H, W)
68
  painted_image: numpy array (H, W, 3)
69
  """
70
+
71
  if first_frame_annotation is not None: # first frame mask
72
  # initialisation
73
  mask, labels = self.mapper.convert_mask(first_frame_annotation)
88
  out_mask = torch.argmax(probs, dim=0)
89
  out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
90
 
91
+ final_mask = np.zeros_like(out_mask)
92
+
93
+ # map back
94
+ for k, v in self.mapper.remappings.items():
95
+ final_mask[out_mask == v] = k
96
+
97
+ num_objs = final_mask.max()
98
  painted_image = frame
99
  for obj in range(1, num_objs+1):
100
+ if np.max(final_mask==obj) == 0:
101
+ continue
102
+ painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
103
+
104
+ return final_mask, final_mask, painted_image
105
 
106
  @torch.no_grad()
107
  def sam_refinement(self, frame, logits, ti):
151
  # sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
152
  tracker = BaseTracker(XMEM_checkpoint, device, None, device)
153
 
154
+ # # test for storage efficiency
155
+ # frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
156
+ # first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
157
 
158
+ first_frame_annotation[first_frame_annotation==1] = 15
159
+ first_frame_annotation[first_frame_annotation==2] = 20
160
+
161
+ save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
162
+ if not os.path.exists(save_path):
163
+ os.mkdir(save_path)
 
 
 
 
 
164
 
 
165
  for ti, frame in enumerate(frames):
 
 
 
166
  if ti == 0:
167
  mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
168
  else:
169
  mask, prob, painted_image = tracker.track(frame)
170
  # save
171
  painted_image = Image.fromarray(painted_image)
172
+ painted_image.save(f'{save_path}/{ti:05d}.png')
173
+
174
+ # tracker.clear_memory()
175
+ # for ti, frame in enumerate(frames):
176
+ # print(ti)
177
+ # # if ti > 200:
178
+ # # break
179
+ # if ti == 0:
180
+ # mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
181
+ # else:
182
+ # mask, prob, painted_image = tracker.track(frame)
183
+ # # save
184
+ # painted_image = Image.fromarray(painted_image)
185
+ # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
186
 
187
  # # track anything given in the first frame annotation
188
  # for ti, frame in enumerate(frames):