SoM / sam_utils.py
SkalskiP's picture
Updated the 'sam_utils.py' and 'app.py' modules to implement automated mask generation, result highlighting, and mark generation.
7e50af9
raw
history blame
1.13 kB
import numpy as np
import supervision as sv
from segment_anything.modeling.sam import Sam
from segment_anything import SamPredictor, SamAutomaticMaskGenerator
def sam_inference(
    image: np.ndarray,
    model: Sam
) -> sv.Detections:
    """Segment everything in ``image`` with SAM's automatic mask generator.

    Args:
        image: Image handed to ``SamAutomaticMaskGenerator.generate``.
        model: Loaded SAM model used to build the generator.

    Returns:
        ``sv.Detections`` built from the generator output via
        ``sv.Detections.from_sam``.
    """
    # A fresh generator with default settings is built on every call.
    sam_records = SamAutomaticMaskGenerator(model).generate(image=image)
    return sv.Detections.from_sam(sam_records)
def sam_interactive_inference(
    image: np.ndarray,
    mask: np.ndarray,
    model: Sam,
    num_points: int = 5,
) -> sv.Detections:
    """Turn a user-drawn mask into SAM-predicted masks, one per polygon.

    The input mask is converted to polygons with ``sv.mask_to_polygons``.
    For each polygon, ``num_points`` of its vertices are sampled at random
    (with replacement) and passed to a ``SamPredictor`` as positive point
    prompts. The single mask SAM returns for that prompt becomes one
    detection.

    Args:
        image: Image passed to ``SamPredictor.set_image``.
        mask: User mask; any non-zero pixel is treated as foreground.
        model: Loaded SAM model backing the predictor.
        num_points: Number of vertices sampled as prompts per polygon
            (default 5, the previously hard-coded value).

    Returns:
        ``sv.Detections`` with one boolean mask per polygon and boxes
        derived from those masks. When the input mask yields no polygons,
        the result has zero rows and a mask array of shape ``(0, H, W)``.

    Raises:
        ValueError: If ``num_points`` is less than 1.
    """
    if num_points < 1:
        raise ValueError(f"num_points must be >= 1, got {num_points}")

    polygons = list(sv.mask_to_polygons(mask.astype(bool)))

    region_masks = []
    if polygons:
        # Computing the image embedding is the costly step, so it is only
        # done when there is at least one polygon to prompt with.
        predictor = SamPredictor(model)
        predictor.set_image(image)
        for polygon in polygons:
            # Sampling with replacement also works for polygons that have
            # fewer vertices than num_points.
            point_indexes = np.random.choice(
                polygon.shape[0], size=num_points, replace=True
            )
            point_coords = polygon[point_indexes]
            point_labels = np.ones(num_points)  # label 1 = foreground point
            # predict(...)[0] holds the masks; with multimask_output=False
            # only the first entry is used.
            region_mask = predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=False,
            )[0][0]
            region_masks.append(region_mask)

    if region_masks:
        masks = np.array(region_masks, dtype=bool)
    else:
        # np.array([]) would be 1-D; keep the (n, H, W) layout expected for
        # Detections masks even when there are no detections.
        height, width = image.shape[:2]
        masks = np.zeros((0, height, width), dtype=bool)

    return sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),
        mask=masks
    )