alex committed
Commit fd0980c · 1 Parent(s): c433e00

bounding box fix

app.py CHANGED
@@ -549,10 +549,11 @@ with gr.Blocks(css=css, title="Wan 2.2 Animate --replace", theme=gr.themes.Ocean
         action_button = gr.Button("Wan Animate 🦆", variant='primary', elem_classes="button-gradient")
 
         with gr.Accordion("Preprocessed Data", open=False, visible=True):
-            pose_video = gr.Video(label="Pose Video")
-            bg_video = gr.Video(label="Background Video")
-            face_video = gr.Video(label="Face Video")
-            mask_video = gr.Video(label="Mask Video")
+            with gr.Row():
+                pose_video = gr.Video(label="Pose Video")
+                bg_video = gr.Video(label="Background Video")
+                face_video = gr.Video(label="Face Video")
+                mask_video = gr.Video(label="Mask Video")
 
     with gr.Row():
         with gr.Column(elem_id="col-showcase"):
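For context, the change above wraps the four preview players in a single gr.Row so they render side by side instead of stacked. A minimal standalone sketch of the same layout (component names taken from the diff; the surrounding Blocks/Accordion wrapper is simplified, not the app's full structure):

import gradio as gr

with gr.Blocks() as demo:
    with gr.Accordion("Preprocessed Data", open=False, visible=True):
        # Children of a Row are laid out horizontally;
        # without it, Blocks stacks components vertically.
        with gr.Row():
            pose_video = gr.Video(label="Pose Video")
            bg_video = gr.Video(label="Background Video")
            face_video = gr.Video(label="Face Video")
            mask_video = gr.Video(label="Mask Video")

demo.launch()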
wan/modules/animate/preprocess/process_pipepline.py CHANGED
@@ -94,7 +94,7 @@ class ProcessPipeline():
             canvas = np.zeros_like(refer_img)
             conditioning_image = draw_aapose_by_meta_new(canvas, meta)
             cond_images.append(conditioning_image)
-        masks = self.get_mask(frames, 400, tpl_pose_metas)
+        masks = self.get_mask_from_face_bbox(frames, 400, tpl_pose_metas)
 
         bg_images = []
         aug_masks = []
@@ -352,3 +352,180 @@ class ProcessPipeline():
             metas_list.append(meta)
         return metas_list
 
+    def get_mask_from_face_bbox(self, frames, th_step, kp2ds_all):
+        """
+        Build masks using a face bounding box per key frame (derived from keypoints_face),
+        then propagate with SAM2 across each chunk of frames.
+        """
+        H, W = frames[0].shape[:2]
+
+        def _clip_box(x1, y1, x2, y2, W, H):
+            x1 = max(0, min(int(x1), W - 1))
+            x2 = max(0, min(int(x2), W - 1))
+            y1 = max(0, min(int(y1), H - 1))
+            y2 = max(0, min(int(y2), H - 1))
+            if x2 <= x1: x2 = min(W - 1, x1 + 1)
+            if y2 <= y1: y2 = min(H - 1, y1 + 1)
+            return x1, y1, x2, y2
+
+        frame_num = len(frames)
+        if frame_num < th_step:
+            num_step = 1
+        else:
+            num_step = (frame_num + th_step) // th_step
+
+        all_mask = []
+
+        for step_idx in range(num_step):
+            each_frames = frames[step_idx * th_step:(step_idx + 1) * th_step]
+            kp2ds = kp2ds_all[step_idx * th_step:(step_idx + 1) * th_step]
+            if len(each_frames) == 0:
+                continue
+
+            # pick a few key frames in this chunk
+            key_frame_num = 4 if len(each_frames) > 4 else 1
+            key_frame_step = max(1, len(kp2ds) // key_frame_num)
+            key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))[:key_frame_num]
+
+            # compute face boxes on the selected key frames
+            key_frame_boxes = []
+            for kfi in key_frame_index_list:
+                meta = kp2ds[kfi]
+                # note: get_face_bboxes returns (x1, x2, y1, y2)
+                x1, x2, y1, y2 = get_face_bboxes(
+                    meta['keypoints_face'][:, :2],
+                    scale=1.3,
+                    image_shape=(H, W)
+                )
+                x1, y1, x2, y2 = _clip_box(x1, y1, x2, y2, W, H)
+                key_frame_boxes.append(np.array([x1, y1, x2, y2], dtype=np.float32))
+
+            # init SAM2 for this chunk
+            inference_state = self.predictor.init_state_v2(frames=each_frames)
+            self.predictor.reset_state(inference_state)
+            ann_obj_id = 1
+
+            # seed with box prompts (preferred), else fall back to points
+            for ann_frame_idx, box_xyxy in zip(key_frame_index_list, key_frame_boxes):
+                used_box = False
+                try:
+                    # preferred path: seed with the box prompt if the predictor exposes a box API
+                    _ = self.predictor.add_new_box(
+                        inference_state=inference_state,
+                        frame_idx=ann_frame_idx,
+                        obj_id=ann_obj_id,
+                        box=box_xyxy[None, :]  # shape (1, 4)
+                    )
+                    used_box = True
+                except Exception:
+                    used_box = False
+
+                if not used_box:
+                    # fallback: sample a few positive points inside the box
+                    x1, y1, x2, y2 = box_xyxy.astype(int)
+                    cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
+                    pts = np.array([
+                        [cx, cy],
+                        [x1 + (x2 - x1) // 4, cy],
+                        [x2 - (x2 - x1) // 4, cy],
+                        [cx, y1 + (y2 - y1) // 4],
+                        [cx, y2 - (y2 - y1) // 4],
+                    ], dtype=np.int32)
+                    labels = np.ones(len(pts), dtype=np.int32)  # 1 = positive
+                    _ = self.predictor.add_new_points(
+                        inference_state=inference_state,
+                        frame_idx=ann_frame_idx,
+                        obj_id=ann_obj_id,
+                        points=pts,
+                        labels=labels,
+                    )
+
+            # propagate across the chunk
+            video_segments = {}
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
+                video_segments[out_frame_idx] = {
+                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                    for i, out_obj_id in enumerate(out_obj_ids)
+                }
+
+            # collect masks (single object id)
+            for out_frame_idx in range(len(video_segments)):
+                # (H, W) boolean/uint8
+                mask = next(iter(video_segments[out_frame_idx].values()))
+                mask = mask[0].astype(np.uint8)
+                all_mask.append(mask)
+
+        return all_mask
+
+    def get_mask_from_face_point(self, frames, th_step, kp2ds_all):
+        """
+        Build masks using a single face *center point* per key frame,
+        then propagate with SAM2 across each chunk of frames.
+        """
+        H, W = frames[0].shape[:2]
+
+        frame_num = len(frames)
+        num_step = 1 if frame_num < th_step else (frame_num + th_step) // th_step
+
+        all_mask = []
+
+        for step_idx in range(num_step):
+            each_frames = frames[step_idx * th_step:(step_idx + 1) * th_step]
+            kp2ds = kp2ds_all[step_idx * th_step:(step_idx + 1) * th_step]
+            if len(each_frames) == 0:
+                continue
+
+            # choose a few key frames to seed the object
+            key_frame_num = 4 if len(each_frames) > 4 else 1
+            key_frame_step = max(1, len(kp2ds) // key_frame_num)
+            key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))[:key_frame_num]
+
+            # compute center point from face bbox for each selected key frame
+            center_pts = []
+            for kfi in key_frame_index_list:
+                meta = kp2ds[kfi]
+                # note: get_face_bboxes returns (x1, x2, y1, y2)
+                x1, x2, y1, y2 = get_face_bboxes(
+                    meta['keypoints_face'][:, :2],
+                    scale=1.3,
+                    image_shape=(H, W)
+                )
+                cx = (x1 + x2) // 2
+                cy = (y1 + y2) // 2
+                # clip just in case
+                cx = int(max(0, min(cx, W - 1)))
+                cy = int(max(0, min(cy, H - 1)))
+                center_pts.append(np.array([cx, cy], dtype=np.int32))
+
+            # init SAM2 for this chunk
+            inference_state = self.predictor.init_state_v2(frames=each_frames)
+            self.predictor.reset_state(inference_state)
+            ann_obj_id = 1
+
+            # seed each key frame with a single positive point at the face center
+            for ann_frame_idx, pt in zip(key_frame_index_list, center_pts):
+                pts = pt[None, :]  # shape (1, 2)
+                labels = np.ones(1, dtype=np.int32)  # 1 = positive
+                _ = self.predictor.add_new_points(
+                    inference_state=inference_state,
+                    frame_idx=ann_frame_idx,
+                    obj_id=ann_obj_id,
+                    points=pts,
+                    labels=labels,
+                )
+
+            # propagate across the chunk
+            video_segments = {}
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
+                video_segments[out_frame_idx] = {
+                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                    for i, out_obj_id in enumerate(out_obj_ids)
+                }
+
+            # collect masks (single object id)
+            for out_frame_idx in range(len(video_segments)):
+                mask = next(iter(video_segments[out_frame_idx].values()))
+                mask = mask[0].astype(np.uint8)
+                all_mask.append(mask)
+
+        return all_mask
+
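Both new helpers share the same chunking scheme before handing frames to SAM2. A small self-contained sketch of that arithmetic (th_step=400 matches the call site above; the frame counts are illustrative): note that num_step over-counts by one chunk when frame_num is an exact multiple of th_step, which the `if len(each_frames) == 0: continue` guard absorbs.

# Chunking as used by get_mask_from_face_bbox / get_mask_from_face_point.
def chunk_sizes(frame_num, th_step=400):
    num_step = 1 if frame_num < th_step else (frame_num + th_step) // th_step
    sizes = []
    for step_idx in range(num_step):
        # same slicing as `frames[step_idx * th_step:(step_idx + 1) * th_step]`
        chunk = range(frame_num)[step_idx * th_step:(step_idx + 1) * th_step]
        sizes.append(len(chunk))
    return sizes

print(chunk_sizes(120))   # [120]          -> single short chunk
print(chunk_sizes(799))   # [400, 399]     -> two full-ish chunks
print(chunk_sizes(800))   # [400, 400, 0]  -> trailing empty chunk, skipped by the guard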