File size: 1,590 Bytes
032e687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import numpy as np
import mmcv
from mmengine.visualization import Visualizer

from third_parts.sam2.build_sam import build_sam2
from third_parts.sam2.sam2_image_predictor import SAM2ImagePredictor
from mmdet.structures.mask import bitmap_to_polygon

# Image the demo prompts SAM2 on.
IMG_PATH = "assets/view.jpg"
# Model config name and checkpoint path, both handed to build_sam2().
MODEL_CFG = "sam2_hiera_l.yaml"
MODEL_CKPT = "work_dirs/ckpt/sam2_hiera_large.pt"


def prepare():
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


if __name__ == '__main__':
    # Demo: prompt SAM2 with a single click on IMG_PATH and overlay the
    # highest-scoring predicted mask, its outline and the click point.
    prepare()
    sam2_model = build_sam2(MODEL_CFG, MODEL_CKPT, device="cuda")
    predictor = SAM2ImagePredictor(sam2_model)

    # mmcv.imread defaults to BGR channel order, but
    # SAM2ImagePredictor.set_image expects an RGB HWC array and mmengine's
    # Visualizer draws (and returns) RGB too, so load it as RGB for both.
    image = mmcv.imread(IMG_PATH, channel_order='rgb')
    predictor.set_image(image)

    # One point prompt at pixel (x=500, y=475) with label 1 (foreground,
    # per the SAM2 predictor API).
    input_point = np.array([[500, 475]])
    input_label = np.array([1])

    # A single click is ambiguous, so request several candidate masks and
    # reorder all outputs by predicted score, best first.
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]

    visualizer = Visualizer(image=image)
    # Keep only the top-scoring mask as a boolean array; slicing (rather
    # than indexing) keeps the leading mask axis so the loop below still
    # iterates over masks.
    masks = masks[:1].astype(bool)

    # Collect the contour polygons of every mask to draw as white outlines.
    polygons = []
    for mask in masks:
        contours, _ = bitmap_to_polygon(mask)
        polygons.extend(contours)
    visualizer.draw_polygons(polygons, edge_colors='w', alpha=0.8)
    visualizer.draw_binary_masks(masks, alphas=0.8)

    # Mark the prompt location with a red star.
    visualizer.draw_points(input_point, 'r', marker='*')

    # Rendered overlay (RGB). It is not saved or shown by this script.
    result = visualizer.get_image()