File size: 1,129 Bytes
c263a47
 
 
 
7e50af9
 
 
 
 
 
 
 
 
 
c263a47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
import supervision as sv

from segment_anything.modeling.sam import Sam
from segment_anything import SamPredictor, SamAutomaticMaskGenerator


def sam_inference(
    image: np.ndarray,
    model: Sam
) -> sv.Detections:
    """Segment everything in ``image`` with SAM's automatic mask generator.

    Args:
        image: Image passed straight to ``SamAutomaticMaskGenerator.generate``
            (segment_anything documents this as an HWC uint8 array).
        model: Loaded SAM model used to construct a fresh mask generator.

    Returns:
        ``sv.Detections`` built via ``sv.Detections.from_sam`` from the
        generator's list of mask records.
    """
    generator = SamAutomaticMaskGenerator(model)
    return sv.Detections.from_sam(generator.generate(image=image))


def sam_interactive_inference(
    image: np.ndarray,
    mask: np.ndarray,
    model: Sam,
    *,
    num_points: int = 5,
    rng: "np.random.Generator | None" = None,
) -> sv.Detections:
    """Refine a rough mask into SAM masks using randomly sampled point prompts.

    The input ``mask`` is split into polygons with ``sv.mask_to_polygons``.
    For each polygon, ``num_points`` of its vertices are sampled (with
    replacement) and passed to SAM as positive point prompts, yielding one
    predicted mask per polygon.

    Args:
        image: Image handed to ``SamPredictor.set_image`` (SAM expects an
            HWC uint8 array).
        mask: Rough mask whose non-zero pixels mark the region(s) to segment;
            it is cast to ``bool`` before polygon extraction.
        model: Loaded SAM model wrapped by a ``SamPredictor``.
        num_points: Number of polygon vertices sampled as positive prompts
            per polygon. Defaults to 5 (the previously hard-coded value).
        rng: Optional NumPy ``Generator`` for reproducible sampling. When
            ``None`` (default), the global ``np.random`` state is used.

    Returns:
        ``sv.Detections`` with one boolean mask per polygon and the matching
        bounding boxes. If ``mask`` yields no polygons, an empty detections
        object is returned whose mask has shape ``(0, H, W)`` with ``H, W``
        taken from ``image``.

    Raises:
        ValueError: If ``num_points`` is less than 1.
    """
    if num_points < 1:
        raise ValueError(f"num_points must be >= 1, got {num_points}")

    # ``np.random`` and ``np.random.Generator`` both expose ``choice(a, size, replace)``.
    sampler = np.random if rng is None else rng

    predictor = SamPredictor(model)
    predictor.set_image(image)

    # Kept distinct from the ``mask`` argument so the input is never shadowed.
    predicted_masks = []
    for polygon in sv.mask_to_polygons(mask.astype(bool)):
        point_indexes = sampler.choice(polygon.shape[0], size=num_points, replace=True)
        input_point = polygon[point_indexes]
        input_label = np.ones(num_points)  # SAM: label 1 marks a foreground point
        # predict() returns (masks, scores, logits); with multimask_output=False
        # there is a single mask, so [0][0] selects that one (H, W) mask.
        predicted_mask = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )[0][0]
        predicted_masks.append(predicted_mask)

    if not predicted_masks:
        # np.array([]) would be 1-D, which sv.Detections rejects as a mask;
        # build correctly shaped empty arrays instead.
        height, width = image.shape[:2]
        return sv.Detections(
            xyxy=np.empty((0, 4), dtype=np.float32),
            mask=np.empty((0, height, width), dtype=bool),
        )

    masks = np.array(predicted_masks, dtype=bool)
    return sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),
        mask=masks,
    )