File size: 1,515 Bytes
a167d24
 
 
 
 
 
3c486b2
 
b789647
 
a167d24
50e3316
 
a167d24
 
 
 
 
 
 
50e3316
a167d24
 
3c486b2
 
 
62f1f79
 
 
 
3c486b2
 
a167d24
 
50e3316
a167d24
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from transformers import VitMatteImageProcessor, VitMatteForImageMatting
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
import torchvision.transforms as T
from typing import Dict, List, Any
from io import BytesIO
import base64
# image = Image.open("man.png").convert("RGB")
# trimap = Image.open("mask2.png").convert("L")

# Select the compute device once at import time: the default CUDA device when
# torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


class EndpointHandler:
    """Inference handler that predicts an alpha matte with ViTMatte.

    The processor and model are loaded once at construction time. Each call
    decodes a base64-encoded image and trimap, runs the model without
    gradient tracking, and returns the predicted alpha matte as a PIL image.
    """

    # Hub checkpoint loaded when no other model id is supplied.
    DEFAULT_MODEL_ID = "hustvl/vitmatte-small-composition-1k"

    def __init__(self, path: str = "", *, model_id: str = DEFAULT_MODEL_ID):
        """Load the ViTMatte processor and model and move the model to ``device``.

        Args:
            path: Accepted for compatibility with the caller's handler
                interface (presumably the repository path supplied by the
                hosting runtime — confirm). It is not used; weights are
                loaded from ``model_id`` instead.
            model_id: Checkpoint id or path passed to ``from_pretrained``.
        """
        self.processor = VitMatteImageProcessor.from_pretrained(model_id)
        self.model = VitMatteForImageMatting.from_pretrained(model_id).to(device)

    @staticmethod
    def _decode_image(encoded: str, mode: str) -> Image.Image:
        """Decode a base64 image payload and convert it to PIL ``mode``."""
        return Image.open(BytesIO(base64.b64decode(encoded))).convert(mode)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Predict an alpha matte for one image/trimap pair.

        Args:
            data: Request payload. The fields are read from ``data["inputs"]``
                when that key exists (it is popped, so ``data`` is mutated),
                otherwise from ``data`` itself. Required keys:
                ``"image"``: base64-encoded image bytes, converted to RGB.
                ``"trimap"``: base64-encoded trimap bytes, converted to "L".

        Returns:
            ``{"result": alpha}``, where ``alpha`` is a single-channel PIL
            image with the same width and height as the decoded input image.

        Raises:
            KeyError: if ``"image"`` or ``"trimap"`` is missing.
        """
        payload = data.pop("inputs", data)

        image = self._decode_image(payload["image"], "RGB")
        trimap = self._decode_image(payload["trimap"], "L")

        model_inputs = self.processor(
            images=image, trimaps=trimap, return_tensors="pt").to(device)

        with torch.no_grad():
            alphas = self.model(**model_inputs).alphas

        # VitMatteImageProcessor pads its inputs on the bottom/right up to a
        # multiple of its size divisibility, so the predicted matte can be
        # larger than the input. Take the single batch item and channel
        # explicitly (torch.squeeze would also drop a spatial axis of size 1).
        # Then crop the padding away so the matte aligns with the input.
        width, height = image.size
        alpha = alphas[0, 0, :height, :width].cpu()

        return {"result": T.ToPILImage()(alpha)}