# DenseLabelDev: tools/sam2/video_predictor.py
# Uploaded by zhouyik via huggingface_hub ("Upload folder using huggingface_hub"), commit 032e687.
import os
import torch
import numpy as np
import mmcv
import mmengine
from mmengine.visualization import Visualizer
from third_parts.sam2.build_sam import build_sam2_video_predictor
from mmdet.structures.mask import bitmap_to_polygon
# Directory of JPEG frames to segment; frame files are ordered by their
# integer stem further down (e.g. 0.jpg, 1.jpg, ...).
VID_PATH: str = "assets/vid_view"
# Checkpoint and config name passed to build_sam2_video_predictor.
MODEL_CKPT: str = "work_dirs/ckpt/sam2_hiera_large.pt"
MODEL_CFG: str = "sam2_hiera_l.yaml"
def prepare():
    """Configure process-wide PyTorch numerics for SAM2 video inference.

    When CUDA is available, enter a ``torch.autocast`` context for CUDA and
    deliberately never exit it, so every later CUDA op in this process runs
    under autocast. The dtype is bfloat16 when the current device supports
    it, float16 otherwise. On GPUs with compute capability >= 8 TF32 is also
    enabled for matmuls and cuDNN.

    Returns:
        The entered ``torch.autocast`` object, so a caller can undo the setup
        with ``ctx.__exit__(None, None, None)``, or ``None`` when CUDA is not
        available and nothing was changed.
    """
    if not torch.cuda.is_available():
        # torch.autocast(device_type="cuda") would only warn and disable
        # itself here, and the TF32 switches are meaningless without a GPU.
        return None

    # Some PyTorch versions raise inside torch.autocast when the device has
    # no bfloat16 support; fall back to float16 there instead of crashing.
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    ctx = torch.autocast(device_type="cuda", dtype=dtype)
    # Entered without a matching __exit__ on purpose: autocast should stay
    # active for the rest of the script.
    ctx.__enter__()

    # TF32 tensor-core math exists only on compute capability 8.x and newer.
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    return ctx
def list_frame_names(vid_path):
    """Return the JPEG frame file names in *vid_path*, in frame order.

    Only files ending in .jpg/.jpeg/.JPG/.JPEG are kept. Frames are ordered by
    the integer value of their stem (so "10.jpg" sorts after "2.jpg"); a
    non-numeric stem makes ``int`` raise ``ValueError``.
    """
    names = [
        name for name in os.listdir(vid_path)
        if os.path.splitext(name)[-1] in (".jpg", ".jpeg", ".JPG", ".JPEG")
    ]
    names.sort(key=lambda name: int(os.path.splitext(name)[0]))
    return names


def stack_object_masks(masks):
    """Stack per-object masks into a single ``(num_objects, H, W)`` array.

    Args:
        masks: Mapping of object id -> mask array whose first element
            ``mask[0]`` is the 2-D ``(H, W)`` bitmap (as built from
            ``out_mask_logits[i] > 0.0`` in the script below).

    Returns:
        The bitmaps stacked along a new leading axis, or ``None`` when
        *masks* is empty.
    """
    bitmaps = [mask[0] for mask in masks.values()]
    if not bitmaps:
        return None
    # np.stack (not np.concatenate) adds the object axis; concatenating 2-D
    # bitmaps along axis 0 would produce an (N*H, W) array instead.
    return np.stack(bitmaps, axis=0)


if __name__ == '__main__':
    prepare()
    predictor = build_sam2_video_predictor(MODEL_CFG, MODEL_CKPT)
    inference_state = predictor.init_state(video_path=VID_PATH)

    # Prompt: a single click with label 1 on frame 0, assigned to object id 1.
    # NOTE(review): SAM2 point coordinates are presumably (x, y) — confirm.
    input_point = np.array([[255, 475]])
    input_label = np.array([1])
    ann_frame_idx = 0
    ann_obj_id = 1
    predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=input_point,
        labels=input_label,
    )

    # video_segments[frame_idx][obj_id] -> boolean mask (logits thresholded at 0).
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }

    # Visualization: overlay contours and masks on every frame, write to ./result.
    frame_names = list_frame_names(VID_PATH)
    mmengine.mkdir_or_exist("./result")
    for idx, frame_name in enumerate(frame_names):
        # mmengine's Visualizer draws on RGB images; mmcv.imread defaults to BGR.
        image = mmcv.imread(os.path.join(VID_PATH, frame_name), channel_order='rgb')
        visualizer = Visualizer(image=image)
        # A frame with no propagated result is drawn without overlays
        # instead of raising KeyError.
        masks = video_segments.get(idx, {})

        polygons = []
        for mask in masks.values():
            contours, _ = bitmap_to_polygon(mask[0])
            polygons.extend(contours)
        if polygons:
            visualizer.draw_polygons(polygons, edge_colors='w', alpha=0.8)

        stacked = stack_object_masks(masks)
        if stacked is not None:
            visualizer.draw_binary_masks(stacked, alphas=0.8)

        # get_image() returns RGB, while mmcv.imwrite expects BGR.
        result = visualizer.get_image()
        mmcv.imwrite(mmcv.rgb2bgr(result), os.path.join('./result', frame_name))