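# test-sam-handler / handler.py
# Author: Gilad Avidan
# Commit e724e19: "initial commit"
# (Header recovered from a Hugging Face Hub file page: raw / history / blame
#  links, "No virus" scan badge, file size 1.53 kB.)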
import base64
from io import BytesIO
from typing import Dict, List, Any
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
class EndpointHandler:
    """Custom Inference Endpoints handler that runs SAM (Segment Anything).

    The model and processor are loaded once at construction. Each call decodes
    a base64-encoded image, prompts SAM with the supplied points, and returns
    the post-processed masks together with the model's predicted IoU scores.
    """

    # Checkpoint loaded when no other is requested.
    DEFAULT_MODEL_ID = "facebook/sam-vit-huge"

    def __init__(self, path: str = "", model_id: str = DEFAULT_MODEL_ID):
        """Preload the SAM model and processor used at inference.

        Args:
            path: Directory handed in by the serving runtime. Accepted for
                interface compatibility only; weights are NOT loaded from it.
            model_id: Hub id (or local directory) of the SAM checkpoint passed
                to ``from_pretrained``. Defaults to ``"facebook/sam-vit-huge"``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SamModel.from_pretrained(model_id).to(self.device)
        self.processor = SamProcessor.from_pretrained(model_id)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Segment the region of an image indicated by point prompts.

        Args:
            data: Request payload. ``data["inputs"]`` (or ``data`` itself when
                there is no ``"inputs"`` key) must provide:
                  * ``"image"``: the image as a base64-encoded string;
                  * ``"points"``: the point prompt(s) for that image. They are
                    wrapped in one extra list (one image per request) before
                    being passed to ``SamProcessor`` as ``input_points``.
                An optional top-level ``"parameters"`` key is removed from
                ``data`` but is currently ignored.

        Returns:
            A JSON-serializable dict:
              * ``"masks"``: one nested list per input image, holding the
                masks returned by ``post_process_masks`` (binarized by that
                method's default);
              * ``"scores"``: nested list of the model's predicted IoU scores.

        Raises:
            KeyError: if ``"image"`` or ``"points"`` is missing from the inputs.
            Errors from base64 decoding or PIL image loading propagate unchanged.
        """
        inputs = data.pop("inputs", data)
        # Accepted for API compatibility; no parameters are honoured yet.
        data.pop("parameters", None)

        # Decode base64 image to PIL.
        image_bytes = base64.b64decode(inputs["image"])
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        # Outer list groups the points per image; each request carries one image.
        input_points = [inputs["points"]]

        model_inputs = self.processor(
            image, input_points=input_points, return_tensors="pt"
        ).to(self.device)

        # Pure inference: disable autograd so intermediate activations of the
        # forward pass are not retained.
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            model_inputs["original_sizes"].cpu(),
            model_inputs["reshaped_input_sizes"].cpu(),
        )
        scores = outputs.iou_scores.cpu()

        # Tensors (possibly on the GPU) are not JSON-serializable; return plain
        # nested lists so the serving runtime can serialize the response.
        return {
            "masks": [mask.tolist() for mask in masks],
            "scores": scores.tolist(),
        }