chongzhou committed
Commit cf4b18a · 1 Parent(s): 238c545

move segment_with_points to CPU

Files changed (2)
  1. app.py +56 -77
  2. sam2/sam2_video_predictor.py +1 -1
app.py CHANGED
@@ -246,7 +246,6 @@ def preprocess_video_in(
     ]
 
 
-@spaces.GPU(duration=5)
 def segment_with_points(
     point_type,
     first_frame,
@@ -256,68 +255,64 @@ def segment_with_points(
     inference_state,
     evt: gr.SelectData,
 ):
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
-    if torch.cuda.get_device_properties(0).major >= 8:
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-        input_points.append(evt.index)
-        print(f"TRACKING INPUT POINT: {input_points}")
-
-        if point_type == "include":
-            input_labels.append(1)
-        elif point_type == "exclude":
-            input_labels.append(0)
-        print(f"TRACKING INPUT LABEL: {input_labels}")
-
-        # Open the image and get its dimensions
-        transparent_background = Image.fromarray(first_frame).convert("RGBA")
-        w, h = transparent_background.size
-
-        # Define the circle radius as a fraction of the smaller dimension
-        fraction = 0.01  # You can adjust this value as needed
-        radius = int(fraction * min(w, h))
-
-        # Create a transparent layer to draw on
-        transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
-
-        for index, track in enumerate(input_points):
-            if input_labels[index] == 1:
-                cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
-            else:
-                cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
-
-        # Convert the transparent layer back to an image
-        transparent_layer = Image.fromarray(transparent_layer, "RGBA")
-        selected_point_map = Image.alpha_composite(
-            transparent_background, transparent_layer
-        )
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+    input_points.append(evt.index)
+    print(f"TRACKING INPUT POINT: {input_points}")
+
+    if point_type == "include":
+        input_labels.append(1)
+    elif point_type == "exclude":
+        input_labels.append(0)
+    print(f"TRACKING INPUT LABEL: {input_labels}")
+
+    # Open the image and get its dimensions
+    transparent_background = Image.fromarray(first_frame).convert("RGBA")
+    w, h = transparent_background.size
+
+    # Define the circle radius as a fraction of the smaller dimension
+    fraction = 0.01  # You can adjust this value as needed
+    radius = int(fraction * min(w, h))
+
+    # Create a transparent layer to draw on
+    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
+
+    for index, track in enumerate(input_points):
+        if input_labels[index] == 1:
+            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
+        else:
+            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
+
+    # Convert the transparent layer back to an image
+    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
+    selected_point_map = Image.alpha_composite(
+        transparent_background, transparent_layer
+    )
 
-        # Let's add a positive click at (x, y) = (210, 350) to get started
-        points = np.array(input_points, dtype=np.float32)
-        # for labels, `1` means positive click and `0` means negative click
-        labels = np.array(input_labels, dtype=np.int32)
-        _, _, out_mask_logits = predictor.add_new_points(
-            inference_state=inference_state,
-            frame_idx=0,
-            obj_id=OBJ_ID,
-            points=points,
-            labels=labels,
-        )
+    # Let's add a positive click at (x, y) = (210, 350) to get started
+    points = np.array(input_points, dtype=np.float32)
+    # for labels, `1` means positive click and `0` means negative click
+    labels = np.array(input_labels, dtype=np.int32)
+    _, _, out_mask_logits = predictor.add_new_points(
+        inference_state=inference_state,
+        frame_idx=0,
+        obj_id=OBJ_ID,
+        points=points,
+        labels=labels,
+    )
 
-        mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
-        first_frame_output = Image.alpha_composite(transparent_background, mask_image)
+    mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
+    first_frame_output = Image.alpha_composite(transparent_background, mask_image)
 
-        torch.cuda.empty_cache()
-        return (
-            selected_point_map,
-            first_frame_output,
-            first_frame,
-            all_frames,
-            input_points,
-            input_labels,
-            inference_state,
-        )
+    torch.cuda.empty_cache()
+    return (
+        selected_point_map,
+        first_frame_output,
+        first_frame,
+        all_frames,
+        input_points,
+        input_labels,
+        inference_state,
+    )
 
 
 def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
@@ -338,10 +333,8 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
 @spaces.GPU(duration=30)
 def propagate_to_all(
     video_in,
-    first_frame,
     all_frames,
     input_points,
-    input_labels,
     inference_state,
 ):
     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
@@ -394,14 +387,7 @@ def propagate_to_all(
     # Write the result to a file
     clip.write_videofile(final_vid_output_path, codec="libx264")
 
-    return (
-        gr.update(value=final_vid_output_path),
-        first_frame,
-        all_frames,
-        input_points,
-        input_labels,
-        inference_state,
-    )
+    return gr.update(value=final_vid_output_path)
 
 
 def update_ui():
@@ -586,19 +572,12 @@ with gr.Blocks() as demo:
         fn=propagate_to_all,
         inputs=[
            video_in,
-            first_frame,
             all_frames,
             input_points,
-            input_labels,
             inference_state,
         ],
        outputs=[
             output_video,
-            first_frame,
-            all_frames,
-            input_points,
-            input_labels,
-            inference_state,
        ],
        concurrency_limit=10,
        queue=False,
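
Note on the deleted mixed-precision block: the TF32 flags and the bfloat16 autocast only apply to CUDA, so moving the click handler to CPU means they are simply dropped. A minimal sketch of an alternative that keeps one code path for both devices (the helper name autocast_if_cuda is hypothetical and not part of this commit):

import contextlib

import torch


def autocast_if_cuda(device: str):
    """Return a bfloat16 autocast context on CUDA, or a no-op context on CPU."""
    if device == "cuda":
        # TF32 matmuls only help on Ampere (compute capability 8.x) and newer GPUs.
        if torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    return contextlib.nullcontext()

With such a helper, segment_with_points could wrap its predictor calls in `with autocast_if_cuda(device):` and run unchanged on either device; the commit instead takes the simpler route of a CPU-only path for the interactive clicks while propagate_to_all keeps the GPU.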
sam2/sam2_video_predictor.py CHANGED
@@ -107,7 +107,7 @@ class SAM2VideoPredictor(SAM2Base):
         inference_state["tracking_has_started"] = False
         inference_state["frames_already_tracked"] = {}
         # Warm up the visual backbone and cache the image feature on frame 0
-        # self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
         return inference_state
 
     @classmethod
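
The one-line change above re-enables the warm-up described by the adjacent comment: init_state computes and caches the image features for frame 0 up front, so the first click does not pay for a backbone forward pass. A toy illustration of that warm-up pattern, with made-up names (extract_features, init_state_demo) standing in for the real SAM 2 internals:

from functools import lru_cache


@lru_cache(maxsize=None)
def extract_features(frame_idx: int) -> tuple[float, ...]:
    # Stand-in for an expensive visual-backbone forward pass over one frame.
    return tuple(float(frame_idx + i) for i in range(4))


def init_state_demo() -> dict:
    state = {"tracking_has_started": False, "frames_already_tracked": {}}
    # Warm up: populate the feature cache for frame 0 at init time, so the
    # first user interaction reads from the cache instead of recomputing.
    extract_features(0)
    return state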