Jeney's picture
Add all answers in output
2e2a941
raw
history blame contribute delete
No virus
1.38 kB
import ast
import io
from typing import Any, Dict

import torch
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering
class EndpointHandler:
    """Inference handler for a ViLT visual-question-answering model.

    Loads a ``ViltProcessor`` and a ``ViltForQuestionAnswering`` model from
    ``path`` and answers a text question about an image. The response holds
    the highest-scoring label plus a softmax score for every label in the
    model's ``config.id2label`` mapping.
    """

    def __init__(self, path=""):
        """Load the processor and model from ``path``.

        The model is placed on CUDA when ``torch.cuda.is_available()``,
        otherwise on CPU; the chosen device name is kept in ``self.device``.
        """
        self.processor = ViltProcessor.from_pretrained(path)
        self.model = ViltForQuestionAnswering.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Actually place the model on the selected device; inputs are moved
        # to the same device in __call__.
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Answer a question about an image.

        Args:
            data: Request payload, either ``{"inputs": {...}}`` or the inner
                dict itself. The inner dict must provide ``"image"`` (the
                encoded image file, as raw ``bytes`` or as a string holding a
                Python bytes literal such as ``"b'\\x89PNG...'"``) and
                ``"text"`` (the question).

        Returns:
            ``{"best_answer": label, "answers": [{"answer": label,
            "answer_score": float}, ...]}`` with one entry per label, in
            ``id2label`` index order.

        Raises:
            KeyError: if ``"image"`` or ``"text"`` is missing.
            ValueError: if ``"image"`` cannot be decoded into bytes.
        """
        # ``get`` instead of ``pop`` so the caller's payload is left intact.
        inputs = data.get("inputs", data)
        image_bytes = self._image_bytes(inputs["image"])
        # ViLT preprocessing expects 3-channel images; normalise RGBA,
        # grayscale and palette inputs.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        text = inputs["text"]

        encoding = self.processor(image, text, return_tensors="pt")
        # Tensors must live on the same device as the model weights.
        encoding = {name: tensor.to(self.device) for name, tensor in encoding.items()}
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model(**encoding)

        return self._postprocess(outputs.logits, self.model.config.id2label)

    @staticmethod
    def _image_bytes(payload: Any) -> bytes:
        """Return the raw image bytes carried by ``payload``.

        Accepts ``bytes``/``bytearray`` as-is, or a ``str`` containing a
        Python bytes literal (e.g. ``"b'\\x89PNG...'"``).

        Raises:
            ValueError: if ``payload`` is neither bytes nor a string that
                parses to a bytes literal.
        """
        if isinstance(payload, (bytes, bytearray)):
            return bytes(payload)
        if isinstance(payload, str):
            # SECURITY: this string comes from the request, so it must never
            # be passed to ``eval``. ``ast.literal_eval`` only evaluates
            # literals and cannot execute code.
            try:
                value = ast.literal_eval(payload)
            except (ValueError, SyntaxError, TypeError, RecursionError) as err:
                raise ValueError("'image' is not a valid bytes literal") from err
            if isinstance(value, bytes):
                return value
        raise ValueError("'image' must be bytes or a string containing a bytes literal")

    @staticmethod
    def _postprocess(logits: torch.Tensor, id2label: Dict[int, Any]) -> Dict[str, Any]:
        """Build the response dict from the model logits.

        Only the first row of ``logits`` is used, because one image/question
        pair is encoded per request. ``best_answer`` is the label at the
        row's argmax; ``answers`` pairs every index of the row with its label
        and softmax probability.
        """
        row = logits[0]
        best_idx = int(row.argmax())
        # ``tolist`` copies to host memory once instead of once per element.
        scores = torch.softmax(row, dim=-1).tolist()
        answers = [
            {"answer": id2label[idx], "answer_score": score}
            for idx, score in enumerate(scores)
        ]
        return {"best_answer": id2label[best_idx], "answers": answers}