File size: 2,285 Bytes
e724e19 5f41ad3 e724e19 5f41ad3 e724e19 78cd05e e724e19 5f41ad3 e724e19 5f41ad3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import base64
import gzip
import numpy as np
from io import BytesIO
from typing import Dict, List, Any
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
def pack_bits(boolean_tensor):
    """Pack a boolean mask into a gzip-compressed, MSB-first bit string.

    The mask is flattened in row-major (C) order and packed eight elements per
    byte with ``np.packbits``; the packed bytes are then gzip-compressed. When
    the element count is not a multiple of 8, the final byte is padded with
    zero bits on the right.

    Args:
        boolean_tensor: A ``torch.Tensor`` or numpy array-like of any shape.
            Non-zero / truthy elements are treated as set bits.

    Returns:
        bytes: ``gzip.compress`` of the packed bytes. The original shape is not
        encoded, so callers must transmit it separately to unpack the mask.
    """
    # Normalise torch tensors (possibly on GPU or attached to autograd) to a
    # numpy array so the rest of the function works with a single API.
    if isinstance(boolean_tensor, torch.Tensor):
        boolean_tensor = boolean_tensor.detach().cpu().numpy()
    flat = np.asarray(boolean_tensor, dtype=bool).ravel()
    # np.packbits zero-pads the trailing partial byte itself, so no manual
    # padding or reshaping to (-1, 8) is needed.
    packed = np.packbits(flat).tobytes()
    return gzip.compress(packed)
class EndpointHandler:
    """Inference-endpoint handler running point-prompted SAM segmentation."""

    def __init__(self, path=""):
        """Load the SAM model and processor once, at endpoint start-up.

        Args:
            path: Path supplied by the serving toolkit. Currently unused: the
                public ``facebook/sam-vit-base`` checkpoint is always loaded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(self.device)
        self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Segment the object indicated by prompt points in a single image.

        Args:
            data: Request payload. Its ``"inputs"`` entry (or ``data`` itself
                when that key is absent) must be a dict containing:

                * ``"image"``: base64-encoded image bytes in any format PIL can
                  open; the image is converted to RGB.
                * ``"points"``: prompt points for the image, forwarded to
                  ``SamProcessor`` as ``input_points=[points]`` (presumably a
                  list of ``[x, y]`` pixel coordinates -- confirm with client).

                An optional ``"parameters"`` entry is accepted and ignored.

        Returns:
            A dict with:

            * ``"masks"``: one base64 string per predicted mask, each the
              ``pack_bits`` encoding of that mask;
            * ``"scores"``: the model's IoU score for each mask, same order;
            * ``"shape"``: ``[height, width]`` shared by every mask, needed
              by the client to unpack the bits.

        Raises:
            KeyError: if the inputs dict lacks ``"image"`` or ``"points"``.
        """
        inputs = data.pop("inputs", data)
        # Optional request parameters are accepted but not used yet.
        data.pop("parameters", None)

        image = Image.open(BytesIO(base64.b64decode(inputs['image']))).convert("RGB")
        # The processor expects one list of prompt points per image in the batch.
        input_points = [inputs['points']]
        model_inputs = self.processor(
            image, input_points=input_points, return_tensors="pt"
        ).to(self.device)

        # Pure inference: disable autograd to avoid building a graph per request.
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # Resize the predicted masks back to the original image resolution.
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            model_inputs["original_sizes"].cpu(),
            model_inputs["reshaped_input_sizes"].cpu(),
        )
        # First image, first prompt batch: what remains is a stack of
        # (height, width) masks, one per candidate.
        image_masks = masks[0][0]
        packed = [base64.b64encode(pack_bits(mask)).decode() for mask in image_masks]
        shape = list(image_masks.shape[1:])
        scores = outputs.iou_scores[0][0].tolist()
        return {"masks": packed, "scores": scores, "shape": shape}
|