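"""Minimal SAM2 video segmentation demo.

Builds a SAM2 video predictor, prompts it with a single positive click on the
first frame, propagates the resulting mask through every frame in VID_PATH,
and writes mask/contour visualizations to ./result.
"""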
import os

import torch
import numpy as np
import mmcv
import mmengine
from mmengine.visualization import Visualizer

from third_parts.sam2.build_sam import build_sam2_video_predictor
from mmdet.structures.mask import bitmap_to_polygon

# Directory containing the video as individual JPEG frames (0.jpg, 1.jpg, ...).
VID_PATH = 'assets/vid_view'
# SAM2 checkpoint and matching model config used to build the video predictor.
MODEL_CKPT = "work_dirs/ckpt/sam2_hiera_large.pt"
MODEL_CFG = "sam2_hiera_l.yaml"


def prepare():
    # Enter a global bfloat16 autocast context (kept open for the rest of the
    # script) and allow TF32 matmuls for faster inference on Ampere+ GPUs.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


if __name__ == '__main__':
    prepare()
    predictor = build_sam2_video_predictor(MODEL_CFG, MODEL_CKPT)
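    # init_state reads the frames under VID_PATH and builds the per-video
    # inference state that later prompts and propagation operate on.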
    inference_state = predictor.init_state(video_path=VID_PATH)

    # A single positive click (label 1) at (x, y) = (255, 475).
    input_point = np.array([[255, 475]])
    input_label = np.array([1])

    ann_frame_idx = 0  # the frame index we interact with
    ann_obj_id = 1  # a unique id for the object being prompted

    _frame_idx, out_obj_ids, out_mask_logits = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=input_point,
        labels=input_label,
    )
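
    # Further prompts can be registered before propagation, e.g. a negative
    # click (label 0) to remove a region from the same object, or a positive
    # click with a new obj_id to track a second object. The call below is a
    # commented-out sketch with illustrative coordinates, not part of the demo:
    # predictor.add_new_points(
    #     inference_state=inference_state,
    #     frame_idx=ann_frame_idx,
    #     obj_id=ann_obj_id,
    #     points=np.array([[300, 300]]),
    #     labels=np.array([0]),
    # )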

    video_segments = {}  # {frame_idx: {obj_id: bool mask of shape (1, H, W)}}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }

    # Visualization: overlay each frame with its predicted masks and contours.
    # Scan all the JPEG frame names in the video directory.
    frame_names = [
        p for p in os.listdir(VID_PATH)
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
    ]
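    # Frames are assumed to be named by integer index (0.jpg, 1.jpg, ...), so
    # sorting by the numeric stem restores temporal order.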
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

    mmengine.mkdir_or_exist("./result")
    for idx in range(len(frame_names)):
        image = mmcv.imread(os.path.join(VID_PATH, frame_names[idx]))
        visualizer = Visualizer(image=image)
        masks = video_segments[idx]
        polygons = []
        vis_masks = []
        for obj_id, mask in masks.items():
            # mask has shape (1, H, W); take the 2-D bitmap for this object
            contours, _ = bitmap_to_polygon(mask[0])
            polygons.extend(contours)
            vis_masks.append(mask[0])
        visualizer.draw_polygons(polygons, edge_colors='w', alpha=0.8)
        # stack into (N, H, W) so each object's mask is drawn separately
        visualizer.draw_binary_masks(np.stack(vis_masks, axis=0), alphas=0.8)

        # visualizer.draw_points(input_point, 'r', marker='*')

        result = visualizer.get_image()
        mmcv.imwrite(result, os.path.join('./result', frame_names[idx]))
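
    # Optionally, the rendered frames in ./result could be packed back into a
    # video. A sketch using mmcv.frames2video (assuming the frames keep their
    # integer names, e.g. 0.jpg, 1.jpg; adjust fps and filename_tmpl as needed):
    # mmcv.frames2video('./result', 'result.avi', fps=10, filename_tmpl='{}.jpg')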