Create handler.py

#7
by CarlLee - opened
Files changed (1) hide show
  1. handler.py +67 -0
handler.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
3
+
4
+ from PIL import Image
5
+ import torch
6
+ import base64
7
+ from base64 import b64encode
8
+ import requests
9
+ import json
10
+ import io
11
+
12
+ # Take in base64 string and return cv image
13
+ def stringToRGB(base64_string):
14
+ imgdata = base64.b64decode(str(base64_string))
15
+ img = Image.open(io.BytesIO(imgdata)).convert('RGB')
16
+ # opencv_img= cv2.cvtColor(np.array(img), cv2.COLOR_BGR2RGB)
17
+ return img
18
+
19
+
20
+ def predict_caption(image_str, max_token = 32):
21
+
22
+ num_beams = 4
23
+ gen_kwargs = {"max_length": max_token, "num_beams": num_beams}
24
+
25
+ images = []
26
+ image = stringToRGB(image_str)
27
+ images.append(image)
28
+
29
+ pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
30
+ pixel_values = pixel_values.to(device)
31
+
32
+ output_ids = model.generate(pixel_values, **gen_kwargs)
33
+
34
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
35
+ preds = [pred.strip() for pred in preds]
36
+ return preds[0]
37
+
38
+
39
+ class EndpointHandler():
40
+ def __init__(self, path=""):
41
+ # Preload all the elements you are going to need at inference.
42
+ # pseudo:
43
+ # self.model= load_model(path
44
+ model = VisionEncoderDecoderModel.from_pretrained(path)
45
+ feature_extractor = ViTFeatureExtractor.from_pretrained(path)
46
+ tokenizer = AutoTokenizer.from_pretrained(path)
47
+
48
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+ model = model.to(device)
50
+
51
+
52
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
53
+ """
54
+ data args:
55
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
56
+ kwargs
57
+ Return:
58
+ A :obj:`list` | `dict`: will be serialized and returned
59
+ """
60
+ max_token = data.pop("max_token", 32)
61
+ img_str = data.pop("data", None)
62
+
63
+ caption = predict_caption(img_str, max_token=max_token)
64
+ return {"caption": f"{caption}"}
65
+
66
+ # pseudo
67
+ # self.model(input)