# Hugging Face Hub file-view header (page residue, not part of the handler code):
# nbroad's picture
# nbroad HF staff
# Upload 3 files
# raw history blame
# No virus
# 1.75 kB
import base64
from io import BytesIO
from typing import Dict, List, Any
from transformers import Pix2StructForConditionalGeneration, AutoProcessor
from PIL import Image
import torch
class EndpointHandler():
    """Inference handler that answers questions about images with Pix2Struct.

    The checkpoint ``google/pix2struct-infographics-vqa-large`` and its
    processor are loaded once in ``__init__``; each call decodes the
    base64-encoded image(s) from the request, runs generation and returns the
    decoded text.
    """

    def __init__(self, path: str = ""):
        """Load the model and processor and place the model on the device.

        Args:
            path: Accepted so the handler can be constructed with a
                model-directory argument (Hugging Face Inference Endpoints
                pass one); it is not used, the checkpoint is always loaded by
                its hub name below.
        """
        model_name = "google/pix2struct-infographics-vqa-large"
        self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.text_prompt = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # __call__ moves the processed inputs to self.device, so the weights
        # must live on the same device or generate() fails on CUDA hosts.
        self.model.to(self.device)

    @staticmethod
    def _decode_image(encoded):
        """Turn one base64-encoded image payload into a PIL image."""
        return Image.open(BytesIO(base64.b64decode(encoded)))

    def __call__(self, data: Any) -> Dict[str, List[str]]:
        """Run visual question answering on the request payload.

        Args:
            data (:obj:):
                includes the input data and the parameters for the inference.
                The inputs are read from ``data["inputs"]`` (or from ``data``
                itself when that key is absent) and must provide ``"image"``
                (a base64 string or a list of them) and ``"question"``. The
                optional ``data["parameters"]`` dict is forwarded to
                ``generate`` as keyword arguments. Both keys are popped from
                ``data``, so the caller's dict is modified.
        Return:
            a dictionary with the output of the model. The only key is
            `output` and the value is a list of str.
        Raises:
            KeyError: if ``"image"`` or ``"question"`` is missing.
        """
        inputs = data.pop("inputs", data)
        # An explicit null for "parameters" must not break the ** expansion.
        parameters = data.pop("parameters", None) or {}

        encoded = inputs["image"]
        if isinstance(encoded, list):
            images = [self._decode_image(item) for item in encoded]
        else:
            images = self._decode_image(encoded)
        question = inputs["question"]

        with torch.inference_mode():
            model_inputs = self.processor(
                images=images, text=question, return_tensors="pt"
            ).to(self.device)
            raw_output = self.model.generate(**model_inputs, **parameters)
            decoded_output = self.processor.batch_decode(
                raw_output, skip_special_tokens=True
            )

        # postprocess the prediction
        return {
            "output": decoded_output
        }