File size: 1,383 Bytes
3d19edf
1de5e91
3d19edf
253f1f9
1de5e91
253f1f9
 
 
 
 
 
176c89c
253f1f9
 
 
2e2a941
253f1f9
781701c
 
1de5e91
781701c
2e2a941
253f1f9
402c710
 
253f1f9
 
2e2a941
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import ast
import io
from typing import Any, Dict

import torch
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering


class EndpointHandler:
    """Visual-question-answering handler around a ViLT checkpoint.

    NOTE(review): the class name and ``path`` constructor argument look like
    the Hugging Face Inference Endpoints custom-handler convention -- confirm
    against the deployment config.
    """

    def __init__(self, path=""):
        """Load the ViLT processor and question-answering model.

        Args:
            path: Location handed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.processor = ViltProcessor.from_pretrained(path)
        self.model = ViltForQuestionAnswering.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Previously the device was chosen but never used, so inference always
        # ran on CPU. Move the weights once, at load time.
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Answer a natural-language question about an image.

        Args:
            data: Request payload. Either ``{"inputs": {...}}`` or the inner
                mapping itself. The inner mapping must contain ``"image"``
                (raw bytes, or the string repr of a bytes literal such as
                ``"b'\\x89PNG...'"``) and ``"text"`` (the question).

        Returns:
            ``{"best_answer": str, "answers": [{"answer": str,
            "answer_score": float}, ...]}``. ``answers`` has one entry per
            label id in ``model.config.id2label``, in id order, with softmax
            probabilities as scores.

        Raises:
            KeyError: if ``"image"`` or ``"text"`` is missing.
            ValueError / SyntaxError: if a string ``"image"`` is not a
                parseable Python literal.
            TypeError: if ``"image"`` does not yield bytes.
        """
        # ``get`` rather than ``pop``: don't mutate the caller's payload.
        inputs = data.get("inputs", data)
        image_bytes = self._decode_image_bytes(inputs["image"])
        # ViLT's image processor normalizes 3 channels; convert RGBA /
        # grayscale / palette uploads so they don't fail preprocessing.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        text = inputs["text"]

        # Preprocess and place the tensors on the same device as the weights.
        encoding = self.processor(image, text, return_tensors="pt").to(self.device)
        # Serving never backpropagates; skip building the autograd graph.
        with torch.inference_mode():
            outputs = self.model(**encoding)
        return self._rank_answers(outputs.logits, self.model.config.id2label)

    @staticmethod
    def _decode_image_bytes(payload: Any) -> bytes:
        """Return raw image bytes from the request's ``image`` field.

        Accepts raw ``bytes``/``bytearray``, or a string holding a Python
        bytes literal (the format the handler previously decoded with
        ``eval``).
        """
        if isinstance(payload, (bytes, bytearray)):
            return bytes(payload)
        # SECURITY: this string comes straight from the HTTP request. ``eval``
        # would execute arbitrary code; ``ast.literal_eval`` only parses
        # literals, so bytes-literal payloads still decode but injection fails.
        value = ast.literal_eval(payload)
        if not isinstance(value, (bytes, bytearray)):
            raise TypeError("'image' must be raw bytes or a bytes literal string")
        return bytes(value)

    @staticmethod
    def _rank_answers(logits: torch.Tensor, id2label: Dict[int, str]) -> Dict[str, Any]:
        """Turn classifier logits into the best answer plus per-label scores.

        Assumes a single example: ``.item()`` requires the argmax to be one
        element, and only row 0 is scored.
        """
        best_idx = logits.argmax(-1).item()
        # One device->host copy for all labels instead of one sync per label.
        scores = torch.softmax(logits, dim=-1)[0].tolist()
        answers = [
            {"answer": id2label[idx], "answer_score": score}
            for idx, score in enumerate(scores)
        ]
        return {"best_answer": id2label[best_idx], "answers": answers}