mattmdjaga commited on
Commit
04de101
1 Parent(s): e69a11e

Added custom handler

Browse files
Files changed (1) hide show
  1. handler.py +39 -0
handler.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ from transformers import AutoModelForSemanticSegmentation, AutoFeatureExtractor
5
+ import base64
6
+ import torch
7
+ from torch import nn
8
+
9
+ class EndpointHandler():
10
+ def __init__(self, path="."):
11
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ self.model = AutoModelForSemanticSegmentation.from_pretrained(path).to(self.device).eval()
13
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(path)
14
+
15
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
+ """
17
+ data args:
18
+ images (:obj:`PIL.Image`)
19
+ candiates (:obj:`list`)
20
+ Return:
21
+ A :obj:`list`:. The list contains items that are dicts should be liked {"label": "XXX", "score": 0.82}
22
+ """
23
+ inputs = data.pop("inputs", data)
24
+
25
+ # decode base64 image to PIL
26
+ image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
27
+
28
+ # preprocess image
29
+ encoding = self.feature_extractor(images=image, return_tensors="pt")
30
+ pixel_values = encoding["pixel_values"].to(self.device)
31
+ with torch.no_grad():
32
+ outputs = self.model(pixel_values=pixel_values)
33
+ logits = outputs.logits
34
+ upsampled_logits = nn.functional.interpolate(logits,
35
+ size=image.size[::-1],
36
+ mode="bilinear",
37
+ align_corners=False,)
38
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
39
+ return pred_seg.tolist()