# inference/models/sam/segment_anything.py
import base64
from io import BytesIO
from time import perf_counter
from typing import Any, List, Optional, Union

import numpy as np
import onnxruntime
import rasterio.features
import torch
from segment_anything import SamPredictor, sam_model_registry
from shapely.geometry import Polygon as ShapelyPolygon

from inference.core.entities.requests.inference import InferenceRequestImage
from inference.core.entities.requests.sam import (
    SamEmbeddingRequest,
    SamInferenceRequest,
    SamSegmentationRequest,
)
from inference.core.entities.responses.sam import (
    SamEmbeddingResponse,
    SamSegmentationResponse,
)
from inference.core.env import SAM_MAX_EMBEDDING_CACHE_SIZE, SAM_VERSION_ID
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.image_utils import load_image_rgb
from inference.core.utils.postprocess import masks2poly

class SegmentAnything(RoboflowCoreModel):
    """SegmentAnything class for handling segmentation tasks.

    Attributes:
sam: The segmentation model.
predictor: The predictor for the segmentation model.
ort_session: ONNX runtime inference session.
embedding_cache: Cache for embeddings.
image_size_cache: Cache for image sizes.
embedding_cache_keys: Keys for the embedding cache.
low_res_logits_cache: Cache for low resolution logits.
segmentation_cache_keys: Keys for the segmentation cache.
"""

    def __init__(self, *args, model_id: str = f"sam/{SAM_VERSION_ID}", **kwargs):
        """Initializes the SegmentAnything model.

        Args:
            *args: Variable length argument list.
            model_id (str): The model ID to load. Defaults to f"sam/{SAM_VERSION_ID}".
            **kwargs: Arbitrary keyword arguments.
        """
super().__init__(*args, model_id=model_id, **kwargs)
self.sam = sam_model_registry[self.version_id](
checkpoint=self.cache_file("encoder.pth")
)
self.sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
self.predictor = SamPredictor(self.sam)
self.ort_session = onnxruntime.InferenceSession(
self.cache_file("decoder.onnx"),
providers=[
"CUDAExecutionProvider",
"CPUExecutionProvider",
],
)
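        # Simple FIFO caches keyed by image_id, bounded by
        # SAM_MAX_EMBEDDING_CACHE_SIZE (the oldest entry is evicted first).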
self.embedding_cache = {}
self.image_size_cache = {}
self.embedding_cache_keys = []
self.low_res_logits_cache = {}
self.segmentation_cache_keys = []
self.task_type = "unsupervised-segmentation"

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: List of file names.
        """
return ["encoder.pth", "decoder.onnx"]

    def embed_image(self, image: Any, image_id: Optional[str] = None, **kwargs):
"""
Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached,
the cached result will be returned.
Args:
image (Any): The image to be embedded. The format should be compatible with the preproc_image method.
image_id (Optional[str]): An identifier for the image. If provided, the embedding result will be cached
with this ID. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image
and the second element is the shape (height, width) of the processed image.
Notes:
- Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
- The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
the oldest entries are removed.
        Example:
            >>> img_array = ...  # some image array
            >>> model.embed_image(img_array, image_id="sample123")
            (array([...]), (224, 224))
"""
if image_id and image_id in self.embedding_cache:
return (
self.embedding_cache[image_id],
self.image_size_cache[image_id],
)
img_in = self.preproc_image(image)
self.predictor.set_image(img_in)
embedding = self.predictor.get_image_embedding().cpu().numpy()
if image_id:
self.embedding_cache[image_id] = embedding
self.image_size_cache[image_id] = img_in.shape[:2]
self.embedding_cache_keys.append(image_id)
if len(self.embedding_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
cache_key = self.embedding_cache_keys.pop(0)
del self.embedding_cache[cache_key]
del self.image_size_cache[cache_key]
return (embedding, img_in.shape[:2])

    def infer_from_request(self, request: SamInferenceRequest):
        """Performs inference based on the request type.

        Args:
            request (SamInferenceRequest): The inference request.

        Returns:
            Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response.
        """
t1 = perf_counter()
if isinstance(request, SamEmbeddingRequest):
embedding, _ = self.embed_image(**request.dict())
inference_time = perf_counter() - t1
if request.format == "json":
return SamEmbeddingResponse(
embeddings=embedding.tolist(), time=inference_time
)
            elif request.format == "binary":
                binary_vector = BytesIO()
                np.save(binary_vector, embedding)
                binary_vector.seek(0)
                return SamEmbeddingResponse(
                    embeddings=binary_vector.getvalue(), time=inference_time
                )
            else:
                raise ValueError(f"Invalid format {request.format}")
elif isinstance(request, SamSegmentationRequest):
masks, low_res_masks = self.segment_image(**request.dict())
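            # For JSON responses, binarize the logits at the model's mask
            # threshold and convert each binary mask to polygons.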
if request.format == "json":
masks = masks > self.predictor.model.mask_threshold
masks = masks2poly(masks)
low_res_masks = low_res_masks > self.predictor.model.mask_threshold
low_res_masks = masks2poly(low_res_masks)
elif request.format == "binary":
binary_vector = BytesIO()
np.savez_compressed(
binary_vector, masks=masks, low_res_masks=low_res_masks
)
binary_vector.seek(0)
binary_data = binary_vector.getvalue()
return binary_data
else:
raise ValueError(f"Invalid format {request.format}")
response = SamSegmentationResponse(
masks=[m.tolist() for m in masks],
low_res_masks=[m.tolist() for m in low_res_masks],
time=perf_counter() - t1,
)
return response

    def preproc_image(self, image: InferenceRequestImage):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.ndarray: The preprocessed RGB image.
        """
np_image = load_image_rgb(image)
return np_image

    def segment_image(
        self,
        image: Any,
        embeddings: Optional[Union[np.ndarray, List[List[float]]]] = None,
        embeddings_format: Optional[str] = "json",
        has_mask_input: Optional[bool] = False,
        image_id: Optional[str] = None,
        mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
        mask_input_format: Optional[str] = "json",
        orig_im_size: Optional[List[int]] = None,
        point_coords: Optional[List[List[float]]] = None,
        point_labels: Optional[List[int]] = None,
        use_mask_input_cache: Optional[bool] = True,
        **kwargs,
    ):
"""
Segments an image based on provided embeddings, points, masks, or cached results.
If embeddings are not directly provided, the function can derive them from the input image or cache.
Args:
image (Any): The image to be segmented.
embeddings (Optional[Union[np.ndarray, List[List[float]]]]): The embeddings of the image.
Defaults to None, in which case the image is used to compute embeddings.
embeddings_format (Optional[str]): Format of the provided embeddings; either 'json' or 'binary'. Defaults to 'json'.
has_mask_input (Optional[bool]): Specifies whether mask input is provided. Defaults to False.
image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks.
mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input mask for the image.
mask_input_format (Optional[str]): Format of the provided mask input; either 'json' or 'binary'. Defaults to 'json'.
orig_im_size (Optional[List[int]]): Original size of the image when providing embeddings directly.
            point_coords (Optional[List[List[float]]]): Coordinates of prompt points in the image. Defaults to None (treated as an empty list).
            point_labels (Optional[List[int]]): Labels for the prompt points (1 = foreground, 0 = background). Defaults to None (treated as an empty list).
use_mask_input_cache (Optional[bool]): Flag to determine if cached mask input should be used. Defaults to True.
**kwargs: Additional keyword arguments.
Returns:
Tuple[np.ndarray, np.ndarray]: A tuple where the first element is the segmentation masks of the image
and the second element is the low resolution segmentation masks.
Raises:
ValueError: If necessary inputs are missing or inconsistent.
Notes:
- Embeddings, segmentations, and low-resolution logits can be cached to improve performance
on repeated requests for the same image.
- The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
the oldest entries are removed.
"""
        # Resolve the image embedding: compute it (or fetch it from the
        # cache) via the encoder when embeddings are not supplied directly.
        # Use explicit None checks so numpy arrays are never evaluated for
        # truthiness, which would raise a ValueError.
        if embeddings is None:
            if image is None and not image_id:
                raise ValueError(
                    "Must provide either image, cached image_id, or embeddings"
                )
            elif image_id and image is None and image_id not in self.embedding_cache:
                raise ValueError(
                    f"Image ID {image_id} not in embedding cache, must provide the image or embeddings"
                )
embedding, original_image_size = self.embed_image(
image=image, image_id=image_id
)
else:
if not orig_im_size:
raise ValueError(
"Must provide original image size if providing embeddings"
)
original_image_size = orig_im_size
if embeddings_format == "json":
embedding = np.array(embeddings)
elif embeddings_format == "binary":
embedding = np.load(BytesIO(embeddings))
        # SAM's exported ONNX decoder expects a padding point with label -1
        # appended to the prompt points. Copy the inputs first so neither the
        # caller's lists nor any shared state is mutated in place.
        point_coords = list(point_coords) if point_coords else []
        point_coords.append([0, 0])
        point_coords = np.array(point_coords, dtype=np.float32)
        point_coords = np.expand_dims(point_coords, axis=0)
        point_coords = self.predictor.transform.apply_coords(
            point_coords,
            original_image_size,
        )
        point_labels = list(point_labels) if point_labels else []
        point_labels.append(-1)
        point_labels = np.array(point_labels, dtype=np.float32)
        point_labels = np.expand_dims(point_labels, axis=0)
if has_mask_input:
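            # Prefer cached low-res logits from a previous call for this
            # image_id; otherwise fall back to the mask_input provided.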
if (
image_id
and image_id in self.low_res_logits_cache
and use_mask_input_cache
):
mask_input = self.low_res_logits_cache[image_id]
            elif mask_input is None and (
not image_id or image_id not in self.low_res_logits_cache
):
raise ValueError("Must provide either mask_input or cached image_id")
else:
if mask_input_format == "json":
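                    # Rasterize each prompt polygon onto the 256x256 low-res
                    # mask grid the decoder expects.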
polys = mask_input
mask_input = np.zeros((1, len(polys), 256, 256), dtype=np.uint8)
for i, poly in enumerate(polys):
poly = ShapelyPolygon(poly)
raster = rasterio.features.rasterize(
[poly], out_shape=(256, 256)
)
mask_input[0, i, :, :] = raster
elif mask_input_format == "binary":
binary_data = base64.b64decode(mask_input)
mask_input = np.load(BytesIO(binary_data))
else:
mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
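        # Assemble the input feed for SAM's exported ONNX decoder;
        # has_mask_input is a float flag (0.0 or 1.0) gating mask_input.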
ort_inputs = {
"image_embeddings": embedding.astype(np.float32),
"point_coords": point_coords.astype(np.float32),
"point_labels": point_labels,
"mask_input": mask_input.astype(np.float32),
"has_mask_input": (
np.zeros(1, dtype=np.float32)
if not has_mask_input
else np.ones(1, dtype=np.float32)
),
"orig_im_size": np.array(original_image_size, dtype=np.float32),
}
masks, _, low_res_logits = self.ort_session.run(None, ort_inputs)
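        # The decoder returns full-size masks, IoU predictions (unused here),
        # and low-res logits that can seed mask_input on a follow-up call.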
if image_id:
self.low_res_logits_cache[image_id] = low_res_logits
if image_id not in self.segmentation_cache_keys:
self.segmentation_cache_keys.append(image_id)
if len(self.segmentation_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
cache_key = self.segmentation_cache_keys.pop(0)
del self.low_res_logits_cache[cache_key]
masks = masks[0]
low_res_masks = low_res_logits[0]
return masks, low_res_masks
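

# Illustrative usage sketch, not part of the original module: a minimal
# end-to-end run of the embed-then-segment flow. It assumes SegmentAnything()
# can resolve its weights through the Roboflow model cache and that
# load_image_rgb accepts a raw numpy array; both are assumptions here, and
# the image is a synthetic stand-in.
if __name__ == "__main__":
    model = SegmentAnything()

    # Stand-in RGB image; replace with real image data.
    img = np.zeros((480, 640, 3), dtype=np.uint8)

    # Compute the encoder embedding once and cache it under an arbitrary ID.
    embedding, im_size = model.embed_image(img, image_id="demo")

    # Prompt the decoder with one foreground point (label 1); the cached
    # embedding is reused via image_id.
    masks, low_res_masks = model.segment_image(
        image=img,
        image_id="demo",
        point_coords=[[320, 240]],
        point_labels=[1],
    )
    print(masks.shape, low_res_masks.shape)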