import base64
from io import BytesIO
from typing import Dict, List, Any
from transformers import Pix2StructForConditionalGeneration, AutoProcessor
from PIL import Image
import torch


class EndpointHandler:
    """
    A basic handler for a single GPU in Inference Endpoints.
    It should not be used on multiple GPUs or on CPU.
    """

    def __init__(self, *args, **kwargs):
        model_name = "google/pix2struct-infographics-vqa-large"
        # dtype tradeoffs:
        # - float16: works on a T4, but may produce slightly lower-quality generations
        # - bfloat16: does not work on a T4 (works on an A10), better-quality generations
        # - float32: works on all GPUs, best-quality generations, but 30-40% slower
        self.dtype = torch.float16
        self.model = Pix2StructForConditionalGeneration.from_pretrained(
            model_name,
            device_map="cuda:0",
            torch_dtype=self.dtype,
        )
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.device = torch.device("cuda")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[str]]:
        """
        Run inference on a single image or a list of images.

        Args:
            data (Dict[str, Any]):
                includes the input data and the parameters for the inference.

        Return:
            a dictionary with the output of the model. The only key is `output`
            and the value is a list of str.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Accept either a single base64-encoded image or a list of them.
        if isinstance(inputs["image"], list):
            img = [
                Image.open(BytesIO(base64.b64decode(img))) for img in inputs["image"]
            ]
        else:
            img = Image.open(BytesIO(base64.b64decode(inputs["image"])))
        question = inputs["question"]

        with torch.inference_mode():
            model_inputs = self.processor(
                images=img, text=question, return_tensors="pt"
            ).to(self.device, dtype=self.dtype)
            raw_output = self.model.generate(**model_inputs, **parameters)

        # Postprocess the prediction into plain strings.
        decoded_output = self.processor.batch_decode(
            raw_output, skip_special_tokens=True
        )
        return {"output": decoded_output}