import base64
from io import BytesIO
from typing import Dict, List, Any
from transformers import Pix2StructForConditionalGeneration, AutoProcessor
from PIL import Image
import torch


class EndpointHandler:
"""
A basic handler for a single GPU in Inference Endpoints.
Should not be used on multiple GPUs or on CPU.
"""

    def __init__(self, *args, **kwargs):
        model_name = "google/pix2struct-infographics-vqa-large"

        # dtype tradeoffs:
        # - float16: works on T4, may produce slightly lower-quality generations
        # - bfloat16: doesn't work on T4 (works on A10), better generation quality
        # - float32: works on all GPUs, best generation quality, 30-40% slower
        self.dtype = torch.float16
self.model = Pix2StructForConditionalGeneration.from_pretrained(
model_name,
device_map="cuda:0",
torch_dtype=self.dtype,
)
self.processor = AutoProcessor.from_pretrained(model_name)
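        # keep a handle on the device so inputs can be moved there in __call__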
self.device = torch.device("cuda")

    def __call__(self, data: Any) -> Dict[str, List[str]]:
        """
        Answer a question about a single image or a list of images.

        Args:
            data (:obj:`dict`):
                includes the input data and the parameters for the inference.
        Return:
            a dictionary with the output of the model. The only key is `output`,
            and the value is a list of str.
        """
inputs = data.pop("inputs", data)
parameters = data.pop("parameters", {})
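        # "image" is expected to be base64-encoded: either a single string
        # or a list of strings for batched requests.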
        if isinstance(inputs["image"], list):
            img = [
                Image.open(BytesIO(base64.b64decode(b64_img)))
                for b64_img in inputs["image"]
            ]
        else:
            img = Image.open(BytesIO(base64.b64decode(inputs["image"])))
question = inputs["question"]
with torch.inference_mode():
model_inputs = self.processor(
images=img, text=question, return_tensors="pt"
).to(self.device, dtype=self.dtype)
raw_output = self.model.generate(**model_inputs, **parameters)
decoded_output = self.processor.batch_decode(
raw_output, skip_special_tokens=True
)
        return {"output": decoded_output}
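

# --- Example usage (illustrative sketch) ---
# A minimal local smoke test, assuming a CUDA GPU is available and that a
# file named "infographic.png" exists on disk (both are assumptions, not part
# of the handler above). The payload mirrors what Inference Endpoints would
# send to __call__.
if __name__ == "__main__":
    with open("infographic.png", "rb") as f:  # hypothetical input file
        b64_image = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    result = handler(
        {
            "inputs": {"image": b64_image, "question": "What is the title?"},
            "parameters": {"max_new_tokens": 64},
        }
    )
    print(result["output"])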