Spaces: Running on Zero

alex committed · Commit fd0980c · 1 Parent(s): c433e00

bounding box fix

Browse files
- app.py +5 -4
- wan/modules/animate/preprocess/process_pipepline.py +178 -1
app.py
CHANGED

@@ -549,10 +549,11 @@ with gr.Blocks(css=css, title="Wan 2.2 Animate --replace", theme=gr.themes.Ocean
     action_button = gr.Button("Wan Animate 🦆", variant='primary', elem_classes="button-gradient")

     with gr.Accordion("Preprocessed Data", open=False, visible=True):
-
-
-
-
+        with gr.Row():
+            pose_video = gr.Video(label="Pose Video")
+            bg_video = gr.Video(label="Background Video")
+            face_video = gr.Video(label="Face Video")
+            mask_video = gr.Video(label="Mask Video")

     with gr.Row():
         with gr.Column(elem_id="col-showcase"):
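The added block surfaces the four intermediate videos that preprocessing produces. For orientation, a minimal sketch of how such outputs are typically wired up in Gradio; the run_preprocess function, its input, and its return order are illustrative assumptions, not part of this commit:

import gradio as gr

def run_preprocess(video_path):
    # hypothetical hook standing in for the Space's preprocess step;
    # returns paths to the four derived videos in display order
    return "pose.mp4", "bg.mp4", "face.mp4", "mask.mp4"

with gr.Blocks() as demo:
    input_video = gr.Video(label="Input Video")
    action_button = gr.Button("Wan Animate 🦆", variant='primary')
    with gr.Accordion("Preprocessed Data", open=False, visible=True):
        with gr.Row():
            pose_video = gr.Video(label="Pose Video")
            bg_video = gr.Video(label="Background Video")
            face_video = gr.Video(label="Face Video")
            mask_video = gr.Video(label="Mask Video")
    action_button.click(
        run_preprocess,
        inputs=input_video,
        outputs=[pose_video, bg_video, face_video, mask_video],
    )

demo.launch()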
wan/modules/animate/preprocess/process_pipepline.py
CHANGED

@@ -94,7 +94,7 @@ class ProcessPipeline():
         canvas = np.zeros_like(refer_img)
         conditioning_image = draw_aapose_by_meta_new(canvas, meta)
         cond_images.append(conditioning_image)
-        masks = self.
+        masks = self.get_mask_from_face_bbox(frames, 400, tpl_pose_metas)

         bg_images = []
         aug_masks = []
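The helper called here, get_mask_from_face_bbox, is added in the hunk below. One detail worth flagging, and a plausible reading of what the "bounding box fix" in the commit title refers to: get_face_bboxes returns coordinates as (x1, x2, y1, y2), while SAM2-style box prompts expect XYXY order, so the new helper reorders before clipping. The relevant lines from the helper below:

# get_face_bboxes (this repo's helper) yields (x1, x2, y1, y2);
# box prompts expect XYXY = (x1, y1, x2, y2), hence the reorder:
x1, x2, y1, y2 = get_face_bboxes(meta['keypoints_face'][:, :2],
                                 scale=1.3, image_shape=(H, W))
box_xyxy = np.array([x1, y1, x2, y2], dtype=np.float32)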
@@ -352,3 +352,180 @@ class ProcessPipeline():
         metas_list.append(meta)
         return metas_list

+    def get_mask_from_face_bbox(self, frames, th_step, kp2ds_all):
+        """
+        Build masks using a face bounding box per key frame (derived from keypoints_face),
+        then propagate with SAM2 across each chunk of frames.
+        """
+        H, W = frames[0].shape[:2]
+
+        def _clip_box(x1, y1, x2, y2, W, H):
+            x1 = max(0, min(int(x1), W - 1))
+            x2 = max(0, min(int(x2), W - 1))
+            y1 = max(0, min(int(y1), H - 1))
+            y2 = max(0, min(int(y2), H - 1))
+            if x2 <= x1: x2 = min(W - 1, x1 + 1)
+            if y2 <= y1: y2 = min(H - 1, y1 + 1)
+            return x1, y1, x2, y2
+
+        frame_num = len(frames)
+        if frame_num < th_step:
+            num_step = 1
+        else:
+            num_step = (frame_num + th_step) // th_step
+
+        all_mask = []
+
+        for step_idx in range(num_step):
+            each_frames = frames[step_idx * th_step:(step_idx + 1) * th_step]
+            kp2ds = kp2ds_all[step_idx * th_step:(step_idx + 1) * th_step]
+            if len(each_frames) == 0:
+                continue
+
+            # pick a few key frames in this chunk
+            key_frame_num = 4 if len(each_frames) > 4 else 1
+            key_frame_step = max(1, len(kp2ds) // key_frame_num)
+            key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))[:key_frame_num]
+
+            # compute face boxes on the selected key frames
+            key_frame_boxes = []
+            for kfi in key_frame_index_list:
+                meta = kp2ds[kfi]
+                # get_face_bboxes returns (x1, x2, y1, y2)
+                x1, x2, y1, y2 = get_face_bboxes(
+                    meta['keypoints_face'][:, :2],
+                    scale=1.3,
+                    image_shape=(H, W)
+                )
+                x1, y1, x2, y2 = _clip_box(x1, y1, x2, y2, W, H)
+                key_frame_boxes.append(np.array([x1, y1, x2, y2], dtype=np.float32))
+
+            # init SAM2 for this chunk
+            inference_state = self.predictor.init_state_v2(frames=each_frames)
+            self.predictor.reset_state(inference_state)
+            ann_obj_id = 1
+
+            # seed with box prompts (preferred), else fall back to points
+            for ann_frame_idx, box_xyxy in zip(key_frame_index_list, key_frame_boxes):
+                used_box = False
+                try:
+                    # ideal path, if the predictor exposes a box API
+                    _ = self.predictor.add_new_box(
+                        inference_state=inference_state,
+                        frame_idx=ann_frame_idx,
+                        obj_id=ann_obj_id,
+                        box=box_xyxy[None, :]  # shape (1, 4)
+                    )
+                    used_box = True
+                except Exception:
+                    used_box = False
+
+                if not used_box:
+                    # Fallback: sample a few positive points inside the box
+                    x1, y1, x2, y2 = box_xyxy.astype(int)
+                    cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
+                    pts = np.array([
+                        [cx, cy],
+                        [x1 + (x2 - x1) // 4, cy],
+                        [x2 - (x2 - x1) // 4, cy],
+                        [cx, y1 + (y2 - y1) // 4],
+                        [cx, y2 - (y2 - y1) // 4],
+                    ], dtype=np.int32)
+                    labels = np.ones(len(pts), dtype=np.int32)  # 1 = positive
+                    _ = self.predictor.add_new_points(
+                        inference_state=inference_state,
+                        frame_idx=ann_frame_idx,
+                        obj_id=ann_obj_id,
+                        points=pts,
+                        labels=labels,
+                    )
+
+            # propagate across the chunk
+            video_segments = {}
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
+                video_segments[out_frame_idx] = {
+                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                    for i, out_obj_id in enumerate(out_obj_ids)
+                }
+
+            # collect masks (single object id)
+            for out_frame_idx in range(len(video_segments)):
+                # (H, W) boolean/uint8
+                mask = next(iter(video_segments[out_frame_idx].values()))
+                mask = mask[0].astype(np.uint8)
+                all_mask.append(mask)
+
+        return all_mask
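A note on the try/except above: upstream SAM2 (facebookresearch/sam2) exposes box prompts through the combined add_new_points_or_box method rather than a separate add_new_box, which is presumably why the code probes for a box API and falls back to points. Whether this Space's wrapped predictor forwards that method is an assumption, but a sketch of the seeding call against the upstream signature, using the same in-scope names as the helper above, would be:

# hedged sketch against upstream SAM2's video-predictor API;
# assumes self.predictor forwards add_new_points_or_box
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    box=box_xyxy,  # XYXY, shape (4,)
)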
+    def get_mask_from_face_point(self, frames, th_step, kp2ds_all):
+        """
+        Build masks using a single face *center point* per key frame,
+        then propagate with SAM2 across each chunk of frames.
+        """
+        H, W = frames[0].shape[:2]
+
+        frame_num = len(frames)
+        num_step = 1 if frame_num < th_step else (frame_num + th_step) // th_step
+
+        all_mask = []
+
+        for step_idx in range(num_step):
+            each_frames = frames[step_idx * th_step:(step_idx + 1) * th_step]
+            kp2ds = kp2ds_all[step_idx * th_step:(step_idx + 1) * th_step]
+            if len(each_frames) == 0:
+                continue
+
+            # choose a few key frames to seed the object
+            key_frame_num = 4 if len(each_frames) > 4 else 1
+            key_frame_step = max(1, len(kp2ds) // key_frame_num)
+            key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))[:key_frame_num]
+
+            # compute center point from face bbox for each selected key frame
+            center_pts = []
+            for kfi in key_frame_index_list:
+                meta = kp2ds[kfi]
+                # get_face_bboxes returns (x1, x2, y1, y2)
+                x1, x2, y1, y2 = get_face_bboxes(
+                    meta['keypoints_face'][:, :2],
+                    scale=1.3,
+                    image_shape=(H, W)
+                )
+                cx = (x1 + x2) // 2
+                cy = (y1 + y2) // 2
+                # clip just in case
+                cx = int(max(0, min(cx, W - 1)))
+                cy = int(max(0, min(cy, H - 1)))
+                center_pts.append(np.array([cx, cy], dtype=np.int32))
+
+            # init SAM2 for this chunk
+            inference_state = self.predictor.init_state_v2(frames=each_frames)
+            self.predictor.reset_state(inference_state)
+            ann_obj_id = 1
+
+            # seed each key frame with a single positive point at the face center
+            for ann_frame_idx, pt in zip(key_frame_index_list, center_pts):
+                pts = pt[None, :]  # shape (1, 2)
+                labels = np.ones(1, dtype=np.int32)  # 1 = positive
+                _ = self.predictor.add_new_points(
+                    inference_state=inference_state,
+                    frame_idx=ann_frame_idx,
+                    obj_id=ann_obj_id,
+                    points=pts,
+                    labels=labels,
+                )
+
+            # propagate across the chunk
+            video_segments = {}
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
+                video_segments[out_frame_idx] = {
+                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                    for i, out_obj_id in enumerate(out_obj_ids)
+                }
+
+            # collect masks (single object id)
+            for out_frame_idx in range(len(video_segments)):
+                mask = next(iter(video_segments[out_frame_idx].values()))
+                mask = mask[0].astype(np.uint8)
+                all_mask.append(mask)
+
+        return all_mask
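Both helpers share the same chunking rule: frames are processed in windows of th_step (400 at the call site in this commit), with num_step = (frame_num + th_step) // th_step once frame_num >= th_step. That formula allocates one extra, empty step whenever frame_num is an exact multiple of th_step, which is what the len(each_frames) == 0 guard absorbs. A quick standalone check of the arithmetic:

# standalone check of the chunking arithmetic used by both helpers
def num_chunks(frame_num, th_step=400):
    return 1 if frame_num < th_step else (frame_num + th_step) // th_step

for n in (120, 400, 799, 800):
    sizes = [min(n, (i + 1) * 400) - i * 400
             for i in range(num_chunks(n)) if i * 400 < n]  # mirrors the empty-chunk guard
    print(n, num_chunks(n), sizes)
# 120 -> 1 step,  [120]
# 400 -> 2 steps, [400]      (second step empty, skipped)
# 799 -> 2 steps, [400, 399]
# 800 -> 3 steps, [400, 400] (third step empty, skipped)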