Files changed (1) hide show
  1. handler.py +24 -0
handler.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+ from transformers import ViltProcessor, ViltForQuestionAnswering
3
+
4
+
5
+ class EndpointHandler:
6
+ def __init__(self, path=""):
7
+ # load model and processor from path
8
+ self.processor = AutoTokenizer.from_pretrained(path)
9
+ self.model = ViltForQuestionAnswering.from_pretrained(path)
10
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
13
+ # process input
14
+ image = data.pop("image", data)
15
+ text = data.pop("text", data)
16
+ parameters = data.pop("parameters", None)
17
+
18
+ # preprocess
19
+ encoding = processor(image, text, return_tensors="pt")
20
+ outputs = model(**encoding)
21
+ # postprocess the prediction
22
+ logits = outputs.logits
23
+ idx = logits.argmax(-1).item()
24
+ return [{"answer": model.config.id2label[idx]}]