import os
import logging
import torch
import cv2
import numpy as np

from typing import List, Dict, Optional
from label_studio_ml.utils import get_image_local_path, InMemoryLRUDictCache

logger = logging.getLogger(__name__)

# Checkpoint locations and Label Studio credentials, read once at import time.
VITH_CHECKPOINT = os.environ.get("VITH_CHECKPOINT")
ONNX_CHECKPOINT = os.environ.get("ONNX_CHECKPOINT")
MOBILESAM_CHECKPOINT = os.environ.get("MOBILESAM_CHECKPOINT", "mobile_sam.pt")
LABEL_STUDIO_ACCESS_TOKEN = os.environ.get("LABEL_STUDIO_ACCESS_TOKEN")
LABEL_STUDIO_HOST = os.environ.get("LABEL_STUDIO_HOST")


def _as_float32_array(value) -> Optional[np.ndarray]:
    """Convert an optional prompt (list or array) to a float32 ndarray.

    Returns None when ``value`` is None or empty. This matches the previous
    ``if value`` check for lists, but also accepts numpy arrays, whose truth
    value is ambiguous when they hold more than one element.
    """
    if value is None or len(value) == 0:
        return None
    return np.array(value, dtype=np.float32)


class SAMPredictor(object):
    """Wrap a SAM-family model behind a single ``predict`` entry point.

    Supported ``model_choice`` values:
      * ``'ONNX'``      - SAM ViT-H computes image embeddings; an ONNX session
                          (``ONNX_CHECKPOINT``) decodes the masks.
      * ``'SAM'``       - SAM ViT-H end to end (``VITH_CHECKPOINT``).
      * ``'MobileSAM'`` - MobileSAM ViT-T end to end (``MOBILESAM_CHECKPOINT``).
    """

    def __init__(self, model_choice):
        """Load the requested model onto CUDA if available, otherwise CPU.

        :param model_choice: ``'ONNX'``, ``'SAM'`` or ``'MobileSAM'``.
        :raises FileNotFoundError: a required checkpoint env var is not set.
        :raises ValueError: ``model_choice`` is not one of the supported values.
        """
        self.model_choice = model_choice

        # Cache for image shapes / embeddings.
        # TODO: currently it supports only one image in cache, since
        # predictor.set_image() must be called each time a new image comes in,
        # before making predictions. To extend it to >1 image, we need to store
        # the "active image" state in the cache.
        self.cache = InMemoryLRUDictCache(1)

        # Fall back to CPU when CUDA is unavailable (works, but is much slower).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.debug(f"Using device {self.device}")

        if model_choice == 'ONNX':
            import onnxruntime
            from segment_anything import sam_model_registry, SamPredictor

            self.model_checkpoint = VITH_CHECKPOINT
            if self.model_checkpoint is None:
                raise FileNotFoundError("VITH_CHECKPOINT is not set: please set it to the path to the SAM checkpoint")
            if ONNX_CHECKPOINT is None:
                raise FileNotFoundError("ONNX_CHECKPOINT is not set: please set it to the path to the ONNX checkpoint")
            logger.info(f"Using ONNX checkpoint {ONNX_CHECKPOINT} and SAM checkpoint {self.model_checkpoint}")
            # The ONNX session only decodes masks; the full SAM model built below
            # is still used to compute the image embedding (see set_image).
            self.ort = onnxruntime.InferenceSession(ONNX_CHECKPOINT)
            reg_key = "vit_h"
        elif model_choice == 'SAM':
            from segment_anything import SamPredictor, sam_model_registry

            self.model_checkpoint = VITH_CHECKPOINT
            if self.model_checkpoint is None:
                raise FileNotFoundError("VITH_CHECKPOINT is not set: please set it to the path to the SAM checkpoint")
            logger.info(f"Using SAM checkpoint {self.model_checkpoint}")
            reg_key = "vit_h"
        elif model_choice == 'MobileSAM':
            from mobile_sam import SamPredictor, sam_model_registry

            self.model_checkpoint = MOBILESAM_CHECKPOINT
            if not self.model_checkpoint:
                raise FileNotFoundError(
                    "MOBILESAM_CHECKPOINT is not set: please set it to the path to the MobileSAM checkpoint"
                )
            logger.info(f"Using MobileSAM checkpoint {self.model_checkpoint}")
            reg_key = 'vit_t'
        else:
            raise ValueError(f"Invalid model choice {model_choice}")

        sam = sam_model_registry[reg_key](checkpoint=self.model_checkpoint)
        sam.to(device=self.device)
        self.predictor = SamPredictor(sam)

    @property
    def model_name(self):
        """Identifier of the form ``'<model_choice>:<checkpoint>:<device>'``."""
        return f'{self.model_choice}:{self.model_checkpoint}:{self.device}'

    def set_image(self, img_path, calculate_embeddings=True):
        """Make ``img_path`` the predictor's active image, reusing the cache.

        :param img_path: image reference resolved by ``get_image_local_path``.
        :param calculate_embeddings: also extract the image embedding (needed by
            the ONNX decoder) and store it in the payload.
        :returns: payload dict with ``'image_shape'`` (``image.shape[:2]``) and,
            when requested, ``'image_embedding'`` (a numpy array).
        :raises ValueError: the resolved file could not be read as an image.
        """
        payload = self.cache.get(img_path)
        updated = False
        if payload is None:
            # Get image and embeddings
            logger.debug(f'Payload not found for {img_path} in `IN_MEM_CACHE`: calculating from scratch')
            image_path = get_image_local_path(
                img_path,
                label_studio_access_token=LABEL_STUDIO_ACCESS_TOKEN,
                label_studio_host=LABEL_STUDIO_HOST
            )
            image = cv2.imread(image_path)
            # cv2.imread signals failure by returning None instead of raising.
            if image is None:
                raise ValueError(f"Could not read image {img_path} (local path: {image_path})")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            self.predictor.set_image(image)
            payload = {'image_shape': image.shape[:2]}
            logger.debug(f'Finished set_image({img_path}) in `IN_MEM_CACHE`: image shape {image.shape[:2]}')
            updated = True
        else:
            logger.debug(f"Using embeddings for {img_path} from `IN_MEM_CACHE`")

        # A cached payload may have been stored without an embedding (e.g. by a
        # calculate_embeddings=False call). The cache holds a single entry that
        # is written right after predictor.set_image(), so the predictor is
        # presumably still holding this image and the embedding can be read now.
        if calculate_embeddings and 'image_embedding' not in payload:
            image_embedding = self.predictor.get_image_embedding().cpu().numpy()
            payload['image_embedding'] = image_embedding
            logger.debug(f'Finished storing embeddings for {img_path} in `IN_MEM_CACHE`: '
                         f'embedding shape {image_embedding.shape}')
            updated = True

        if updated:
            self.cache.put(img_path, payload)
        return payload

    def predict_onnx(
        self,
        img_path,
        point_coords: Optional[List[List]] = None,
        point_labels: Optional[List] = None,
        input_box: Optional[List] = None
    ) -> Dict:
        """Predict a single mask with the ONNX mask decoder.

        :param img_path: image reference (see ``set_image``).
        :param point_coords: ``[[x, y], ...]`` keypoints in original image pixels.
        :param point_labels: one label per keypoint (1 = positive, 0 = negative).
        :param input_box: ``[x0, y0, x1, y1]`` box in original image pixels.
        :returns: ``{'masks': [uint8 array of shape (H, W)], 'probs': [float]}``.
        :raises ValueError: no prompt was given, or the keypoint labels are
            missing or have a different length than the keypoints.
        """
        onnx_point_coords = _as_float32_array(point_coords)
        onnx_point_labels = _as_float32_array(point_labels)
        onnx_box_coords = _as_float32_array(input_box)
        if onnx_box_coords is not None:
            # The box is encoded as two corner points: top-left, bottom-right.
            onnx_box_coords = onnx_box_coords.reshape(2, 2)

        # Validate prompts before paying for the image embedding.
        if onnx_point_coords is not None:
            if onnx_point_labels is None:
                raise ValueError("point_labels must be provided together with point_coords")
            if len(onnx_point_labels) != len(onnx_point_coords):
                raise ValueError(
                    f"Got {len(onnx_point_coords)} point_coords but {len(onnx_point_labels)} point_labels"
                )
        elif onnx_box_coords is None:
            raise ValueError("ONNX prediction requires point_coords and/or input_box")

        # calculate embeddings
        payload = self.set_image(img_path, calculate_embeddings=True)
        image_shape = payload['image_shape']
        image_embedding = payload['image_embedding']

        # ONNX decoder label convention: 2/3 = box top-left/bottom-right corner,
        # -1 = padding point (required only when there is no box input).
        box_labels = np.array([2, 3], dtype=np.float32)
        if onnx_point_coords is not None and onnx_box_coords is not None:
            # both keypoints and boxes are present
            onnx_coords = np.concatenate([onnx_point_coords, onnx_box_coords], axis=0)[None, :, :]
            onnx_labels = np.concatenate([onnx_point_labels, box_labels], axis=0)[None, :]
        elif onnx_point_coords is not None:
            # only keypoints are present: append the padding point
            padding_coords = np.array([[0.0, 0.0]], dtype=np.float32)
            padding_labels = np.array([-1], dtype=np.float32)
            onnx_coords = np.concatenate([onnx_point_coords, padding_coords], axis=0)[None, :, :]
            onnx_labels = np.concatenate([onnx_point_labels, padding_labels], axis=0)[None, :]
        else:
            # only a box is present: its corners are the whole sparse prompt
            onnx_coords = onnx_box_coords[None, :, :]
            onnx_labels = box_labels[None, :]

        # Map coordinates from original image pixels into the model's resized frame.
        onnx_coords = self.predictor.transform.apply_coords(onnx_coords, image_shape).astype(np.float32)
        onnx_labels = onnx_labels.astype(np.float32)

        # TODO: support mask inputs
        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        onnx_has_mask_input = np.zeros(1, dtype=np.float32)

        ort_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coords,
            "point_labels": onnx_labels,
            "mask_input": onnx_mask_input,
            "has_mask_input": onnx_has_mask_input,
            "orig_im_size": np.array(image_shape, dtype=np.float32)
        }
        masks, prob, low_res_logits = self.ort.run(None, ort_inputs)
        masks = masks > self.predictor.model.mask_threshold
        mask = masks[0, 0, :, :].astype(np.uint8)  # each mask has shape [H, W]
        prob = float(prob[0][0])
        # TODO: support the real multimask output as in https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
        return {
            'masks': [mask],
            'probs': [prob]
        }

    def predict_sam(
        self,
        img_path,
        point_coords: Optional[List[List]] = None,
        point_labels: Optional[List] = None,
        input_box: Optional[List] = None
    ) -> Dict:
        """Predict a single mask with the in-process SAM / MobileSAM predictor.

        :param img_path: image reference (see ``set_image``).
        :param point_coords: ``[[x, y], ...]`` keypoints in original image pixels.
        :param point_labels: one label per keypoint (1 = positive, 0 = negative).
        :param input_box: ``[x0, y0, x1, y1]`` box in original image pixels.
        :returns: ``{'masks': [uint8 array of shape (H, W)], 'probs': [float]}``.
        """
        # The predictor keeps its own embedding internally; no need to cache it.
        self.set_image(img_path, calculate_embeddings=False)
        masks, probs, logits = self.predictor.predict(
            point_coords=_as_float32_array(point_coords),
            point_labels=_as_float32_array(point_labels),
            box=_as_float32_array(input_box),
            # TODO: support multimask output
            multimask_output=False
        )
        mask = masks[0, :, :].astype(np.uint8)  # each mask has shape [H, W]
        prob = float(probs[0])
        return {
            'masks': [mask],
            'probs': [prob]
        }

    def predict(
        self,
        img_path: str,
        point_coords: Optional[List[List]] = None,
        point_labels: Optional[List] = None,
        input_box: Optional[List] = None
    ) -> Dict:
        """Dispatch to ``predict_onnx`` or ``predict_sam`` based on ``model_choice``.

        :raises NotImplementedError: for an unsupported ``model_choice``.
        """
        if self.model_choice == 'ONNX':
            return self.predict_onnx(img_path, point_coords, point_labels, input_box)
        elif self.model_choice in ('SAM', 'MobileSAM'):
            return self.predict_sam(img_path, point_coords, point_labels, input_box)
        else:
            raise NotImplementedError(f"Model choice {self.model_choice} is not supported yet")