import base64
from io import BytesIO
from time import perf_counter
from typing import Any, List, Optional, Union

import numpy as np
import onnxruntime
import rasterio.features
import torch
from segment_anything import SamPredictor, sam_model_registry
from shapely.geometry import Polygon as ShapelyPolygon

from inference.core.entities.requests.inference import InferenceRequestImage
from inference.core.entities.requests.sam import (
    SamEmbeddingRequest,
    SamInferenceRequest,
    SamSegmentationRequest,
)
from inference.core.entities.responses.sam import (
    SamEmbeddingResponse,
    SamSegmentationResponse,
)
from inference.core.env import SAM_MAX_EMBEDDING_CACHE_SIZE, SAM_VERSION_ID
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.image_utils import load_image_rgb
from inference.core.utils.postprocess import masks2poly


class SegmentAnything(RoboflowCoreModel):
    """SegmentAnything class for handling segmentation tasks.

    Attributes:
        sam: The segmentation model.
        predictor: The predictor for the segmentation model.
        ort_session: ONNX runtime inference session.
        embedding_cache: Cache for embeddings.
        image_size_cache: Cache for image sizes.
        embedding_cache_keys: Keys for the embedding cache.
        low_res_logits_cache: Cache for low resolution logits.
        segmentation_cache_keys: Keys for the segmentation cache.
    """

    def __init__(self, *args, model_id: str = f"sam/{SAM_VERSION_ID}", **kwargs):
        """Initializes the SegmentAnything model.

        Args:
            *args: Variable length argument list.
            model_id (str): The model identifier to load. Defaults to the SAM
                version configured via SAM_VERSION_ID.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, model_id=model_id, **kwargs)
        self.sam = sam_model_registry[self.version_id](
            checkpoint=self.cache_file("encoder.pth")
        )
        self.sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
        self.predictor = SamPredictor(self.sam)
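        # The image encoder runs in PyTorch (above), while the lightweight
        # prompt decoder runs through ONNX Runtime so that point and mask
        # prompts can be decoded without re-encoding the image.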
        self.ort_session = onnxruntime.InferenceSession(
            self.cache_file("decoder.onnx"),
            providers=[
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        self.embedding_cache = {}
        self.image_size_cache = {}
        self.embedding_cache_keys = []
        self.low_res_logits_cache = {}
        self.segmentation_cache_keys = []
        self.task_type = "unsupervised-segmentation"

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: List of file names.
        """
        return ["encoder.pth", "decoder.onnx"]

    def embed_image(self, image: Any, image_id: Optional[str] = None, **kwargs):
        """
        Embeds an image and caches the result if an image_id is provided. If the
        image has been embedded before and cached, the cached result is returned.

        Args:
            image (Any): The image to be embedded. The format should be compatible
                with the preproc_image method.
            image_id (Optional[str]): An identifier for the image. If provided,
                the embedding result will be cached under this ID. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is
                the embedding of the image and the second element is the shape
                (height, width) of the processed image.

        Notes:
            - Embeddings and image sizes are cached to improve performance on
              repeated requests for the same image.
            - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE.
              When the cache exceeds this size, the oldest entries are removed.

        Example:
            >>> img_array = ...  # some image array
            >>> embed_image(img_array, image_id="sample123")
            (array([...]), (224, 224))
        """
        if image_id and image_id in self.embedding_cache:
            return (
                self.embedding_cache[image_id],
                self.image_size_cache[image_id],
            )
        img_in = self.preproc_image(image)
        self.predictor.set_image(img_in)
        embedding = self.predictor.get_image_embedding().cpu().numpy()
        if image_id:
            self.embedding_cache[image_id] = embedding
            self.image_size_cache[image_id] = img_in.shape[:2]
            # Evict the oldest cache entries (FIFO) once the cache is full.
            self.embedding_cache_keys.append(image_id)
            if len(self.embedding_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
                cache_key = self.embedding_cache_keys.pop(0)
                del self.embedding_cache[cache_key]
                del self.image_size_cache[cache_key]
        return (embedding, img_in.shape[:2])

    def infer_from_request(self, request: SamInferenceRequest):
        """Performs inference based on the request type.

        Args:
            request (SamInferenceRequest): The inference request.

        Returns:
            Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference
                response.
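
        Example:
            Illustrative sketch only; assumes model is an initialized
            SegmentAnything instance and request is a valid SamEmbeddingRequest
            with format "json".

            >>> response = model.infer_from_request(request)
            >>> isinstance(response, SamEmbeddingResponse)
            True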
""" | |
        t1 = perf_counter()
        if isinstance(request, SamEmbeddingRequest):
            embedding, _ = self.embed_image(**request.dict())
            inference_time = perf_counter() - t1
            if request.format == "json":
                return SamEmbeddingResponse(
                    embeddings=embedding.tolist(), time=inference_time
                )
            elif request.format == "binary":
                # Binary responses carry a raw ``np.save`` payload that clients
                # can recover with ``np.load``.
                binary_vector = BytesIO()
                np.save(binary_vector, embedding)
                binary_vector.seek(0)
                return SamEmbeddingResponse(
                    embeddings=binary_vector.getvalue(), time=inference_time
                )
            else:
                raise ValueError(f"Invalid format {request.format}")
        elif isinstance(request, SamSegmentationRequest):
            masks, low_res_masks = self.segment_image(**request.dict())
            if request.format == "json":
                # Threshold the logits into binary masks, then convert them to
                # polygons for the JSON response.
                masks = masks > self.predictor.model.mask_threshold
                masks = masks2poly(masks)
                low_res_masks = low_res_masks > self.predictor.model.mask_threshold
                low_res_masks = masks2poly(low_res_masks)
            elif request.format == "binary":
                binary_vector = BytesIO()
                np.savez_compressed(
                    binary_vector, masks=masks, low_res_masks=low_res_masks
                )
                binary_vector.seek(0)
                binary_data = binary_vector.getvalue()
                return binary_data
            else:
                raise ValueError(f"Invalid format {request.format}")
            response = SamSegmentationResponse(
                masks=[m.tolist() for m in masks],
                low_res_masks=[m.tolist() for m in low_res_masks],
                time=perf_counter() - t1,
            )
            return response

    def preproc_image(self, image: InferenceRequestImage):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.ndarray: The preprocessed RGB image.
        """
        np_image = load_image_rgb(image)
        return np_image

    def segment_image(
        self,
        image: Any,
        embeddings: Optional[Union[np.ndarray, List[List[float]]]] = None,
        embeddings_format: Optional[str] = "json",
        has_mask_input: Optional[bool] = False,
        image_id: Optional[str] = None,
        mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
        mask_input_format: Optional[str] = "json",
        orig_im_size: Optional[List[int]] = None,
        point_coords: Optional[List[List[float]]] = None,
        point_labels: Optional[List[int]] = None,
        use_mask_input_cache: Optional[bool] = True,
        **kwargs,
    ):
""" | |
Segments an image based on provided embeddings, points, masks, or cached results. | |
If embeddings are not directly provided, the function can derive them from the input image or cache. | |
Args: | |
image (Any): The image to be segmented. | |
embeddings (Optional[Union[np.ndarray, List[List[float]]]]): The embeddings of the image. | |
Defaults to None, in which case the image is used to compute embeddings. | |
embeddings_format (Optional[str]): Format of the provided embeddings; either 'json' or 'binary'. Defaults to 'json'. | |
has_mask_input (Optional[bool]): Specifies whether mask input is provided. Defaults to False. | |
image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks. | |
mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input mask for the image. | |
mask_input_format (Optional[str]): Format of the provided mask input; either 'json' or 'binary'. Defaults to 'json'. | |
orig_im_size (Optional[List[int]]): Original size of the image when providing embeddings directly. | |
point_coords (Optional[List[List[float]]]): Coordinates of points in the image. Defaults to an empty list. | |
point_labels (Optional[List[int]]): Labels associated with the provided points. Defaults to an empty list. | |
use_mask_input_cache (Optional[bool]): Flag to determine if cached mask input should be used. Defaults to True. | |
**kwargs: Additional keyword arguments. | |
Returns: | |
Tuple[np.ndarray, np.ndarray]: A tuple where the first element is the segmentation masks of the image | |
and the second element is the low resolution segmentation masks. | |
Raises: | |
ValueError: If necessary inputs are missing or inconsistent. | |
Notes: | |
- Embeddings, segmentations, and low-resolution logits can be cached to improve performance | |
on repeated requests for the same image. | |
- The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size, | |
the oldest entries are removed. | |
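
        Example:
            Illustrative sketch only; assumes an image was embedded earlier under
            the hypothetical ID "sample123" and uses a single made-up foreground
            point.

            >>> masks, low_res_masks = model.segment_image(
            ...     image=None,
            ...     image_id="sample123",
            ...     point_coords=[[125.0, 80.0]],
            ...     point_labels=[1],
            ... )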
""" | |
        if embeddings is None:
            if image is None and not image_id:
                raise ValueError(
                    "Must provide either image, cached image_id, or embeddings"
                )
            elif image_id and image is None and image_id not in self.embedding_cache:
                raise ValueError(
                    f"Image ID {image_id} not in embedding cache, must provide the image or embeddings"
                )
            embedding, original_image_size = self.embed_image(
                image=image, image_id=image_id
            )
        else:
            if not orig_im_size:
                raise ValueError(
                    "Must provide original image size if providing embeddings"
                )
            original_image_size = orig_im_size
            if embeddings_format == "json":
                embedding = np.array(embeddings)
            elif embeddings_format == "binary":
                embedding = np.load(BytesIO(embeddings))
        # SAM's ONNX decoder expects a padding point with label -1 appended to
        # the user-supplied prompt points. Copy the lists so that callers'
        # inputs are never mutated in place.
        point_coords = list(point_coords) if point_coords else []
        point_coords.append([0, 0])
        point_coords = np.array(point_coords, dtype=np.float32)
        point_coords = np.expand_dims(point_coords, axis=0)
        point_coords = self.predictor.transform.apply_coords(
            point_coords,
            original_image_size,
        )
        point_labels = list(point_labels) if point_labels else []
        point_labels.append(-1)
        point_labels = np.array(point_labels, dtype=np.float32)
        point_labels = np.expand_dims(point_labels, axis=0)
        if has_mask_input:
            if (
                image_id
                and image_id in self.low_res_logits_cache
                and use_mask_input_cache
            ):
                mask_input = self.low_res_logits_cache[image_id]
            elif mask_input is None and (
                not image_id or image_id not in self.low_res_logits_cache
            ):
                raise ValueError("Must provide either mask_input or cached image_id")
            else:
                if mask_input_format == "json":
                    # Rasterize each polygon onto the 256x256 grid expected by
                    # the decoder's mask input.
                    polys = mask_input
                    mask_input = np.zeros((1, len(polys), 256, 256), dtype=np.uint8)
                    for i, poly in enumerate(polys):
                        poly = ShapelyPolygon(poly)
                        raster = rasterio.features.rasterize(
                            [poly], out_shape=(256, 256)
                        )
                        mask_input[0, i, :, :] = raster
                elif mask_input_format == "binary":
                    binary_data = base64.b64decode(mask_input)
                    mask_input = np.load(BytesIO(binary_data))
        else:
            mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
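        # Assemble inputs per the SAM ONNX decoder contract: the encoder's image
        # embeddings, the transformed point prompts, an optional low resolution
        # mask prompt, a flag marking whether that mask is meaningful, and the
        # original image size used to rescale the output masks.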
        ort_inputs = {
            "image_embeddings": embedding.astype(np.float32),
            "point_coords": point_coords.astype(np.float32),
            "point_labels": point_labels,
            "mask_input": mask_input.astype(np.float32),
            "has_mask_input": (
                np.zeros(1, dtype=np.float32)
                if not has_mask_input
                else np.ones(1, dtype=np.float32)
            ),
            "orig_im_size": np.array(original_image_size, dtype=np.float32),
        }
        masks, _, low_res_logits = self.ort_session.run(None, ort_inputs)
        if image_id:
            # Cache the low resolution logits so follow-up prompts for the same
            # image can reuse them as mask input; evict oldest entries (FIFO).
            self.low_res_logits_cache[image_id] = low_res_logits
            if image_id not in self.segmentation_cache_keys:
                self.segmentation_cache_keys.append(image_id)
            if len(self.segmentation_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
                cache_key = self.segmentation_cache_keys.pop(0)
                del self.low_res_logits_cache[cache_key]
        # Drop the batch dimension before returning.
        masks = masks[0]
        low_res_masks = low_res_logits[0]
        return masks, low_res_masks
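

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): assumes model weights are
    # available through the Roboflow weights cache and that the request
    # entities accept the fields used below; the base64 payload is a
    # placeholder, not real image data.
    model = SegmentAnything()
    request = SamEmbeddingRequest(
        image=InferenceRequestImage(type="base64", value="..."),  # placeholder
        image_id="example-image",
        format="json",
    )
    response = model.infer_from_request(request)
    print(f"Embedding computed in {response.time:.3f}s")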