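# handler.py -- custom handler for Hugging Face Inference Endpoints serving
# Meta's Segment Anything Model (SAM, facebook/sam-vit-base). It accepts a
# base64-encoded image plus prompt points and returns gzip-compressed,
# base64-encoded segmentation masks.
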
import base64
import gzip
import numpy as np
from io import BytesIO
from typing import Dict, List, Any
from PIL import Image
import torch
from transformers import SamModel, SamProcessor

def pack_bits(boolean_tensor):
    # Flatten the boolean mask to a 1-D numpy array
    flat = boolean_tensor.flatten().cpu().numpy()

    # Pad to a multiple of 8 so the bits pack evenly into whole bytes
    if flat.size % 8 != 0:
        padding = np.zeros((8 - flat.size % 8,), dtype=bool)
        flat = np.concatenate([flat, padding])

    # Pack 8 booleans per byte, then gzip the resulting byte string
    packed = np.packbits(flat.reshape((-1, 8)))
    return gzip.compress(packed.tobytes())
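
# For reference, a minimal client-side sketch that reverses pack_bits. The
# helper name `unpack_bits` and its signature are assumptions for
# illustration, not part of the deployed handler; it only relies on the
# numpy/gzip/base64 behavior used above.
def unpack_bits(b64_mask: str, shape: List[int]) -> np.ndarray:
    # Hypothetical helper: base64-decode, gunzip, and unpack one mask
    raw = gzip.decompress(base64.b64decode(b64_mask))
    bits = np.unpackbits(np.frombuffer(raw, dtype=np.uint8))
    # Drop the padding bits added during packing, then restore the 2-D shape
    return bits[: shape[0] * shape[1]].astype(bool).reshape(shape)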


class EndpointHandler:
    def __init__(self, path=""):
        # Preload all the elements you are going to need at inference.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(self.device)
        self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`dict`): {"image": base64-encoded image bytes,
                "points": list of [x, y] prompt points in pixel coordinates}
            parameters (:obj:`dict`, optional): extra inference options
        Return:
            A :obj:`dict` with gzip-compressed, base64-encoded masks, their
            IoU scores, and the (height, width) mask shape; it will be
            serialized and returned
        """

        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {"mode": "image"})  # reserved; currently unused

        # Decode the base64-encoded image to an RGB PIL image
        image = Image.open(BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")
        input_points = [inputs["points"]]  # prompt points in (x, y) pixel coordinates

        model_inputs = self.processor(image, input_points=input_points, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        # Upscale the low-resolution predicted masks back to the original image size
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            model_inputs["original_sizes"].cpu(),
            model_inputs["reshaped_input_sizes"].cpu())
        scores = outputs.iou_scores

        # Compress each predicted mask and base64-encode it for JSON transport
        packed = [base64.b64encode(pack_bits(masks[0][0][i])).decode() for i in range(masks[0].shape[1])]
        shape = list(masks[0].shape)[2:]  # (height, width) of the full-resolution masks

        return {"masks": packed, "scores": scores[0][0].tolist(), "shape": shape}