File size: 901 Bytes
3d19edf
 
253f1f9
 
 
 
 
 
 
176c89c
253f1f9
 
 
 
 
781701c
 
 
253f1f9
 
402c710
 
253f1f9
 
 
402c710
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from typing import Any, Dict, List

import torch
from transformers import ViltProcessor, ViltForQuestionAnswering


class EndpointHandler:
    """Visual question answering request handler built on ViLT.

    Loads a ``ViltProcessor`` / ``ViltForQuestionAnswering`` pair from ``path``
    and, per call, answers one question about one image. (The ``__init__(path)``
    / ``__call__(data)`` shape presumably follows the Hugging Face Inference
    Endpoints custom-handler interface — confirm against the deployment.)
    """

    def __init__(self, path=""):
        """Load the processor and model, then place the model on the best device.

        Args:
            path: Location passed to ``from_pretrained`` for both the processor
                and the model.
        """
        self.processor = ViltProcessor.from_pretrained(path)
        self.model = ViltForQuestionAnswering.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Put the model on the selected device once; __call__ moves each
        # request's tensors to the same device.
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Answer a question about an image.

        Args:
            data: Request payload. If it has an ``"inputs"`` key, that entry is
                popped from ``data`` and used; otherwise ``data`` itself is used.
                The chosen mapping must provide ``"image"`` and ``"text"``.

        Returns:
            A one-element list ``[{"answer": label}]``, where ``label`` is the
            model's highest-scoring answer from ``config.id2label``.

        Raises:
            KeyError: if ``"image"`` or ``"text"`` is missing from the inputs.
        """
        inputs = data.pop("inputs", data)
        image = inputs["image"]
        text = inputs["text"]

        encoding = self.processor(image, text, return_tensors="pt")
        # Model and inputs must share a device; non-tensor entries pass through.
        encoding = {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in encoding.items()
        }

        # Pure inference: disable autograd to avoid building a backward graph.
        with torch.no_grad():
            outputs = self.model(**encoding)

        # .item() requires a single element, i.e. one image/question pair.
        idx = outputs.logits.argmax(-1).item()
        return [{"answer": self.model.config.id2label[idx]}]