# gutenocr-endpoint / handler.py
# Author: pavun
# Commit ae505d9: Refactor EndpointHandler constructor to use model_dir
# parameter for consistency.
import base64
import io
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a Qwen2.5-VL OCR model.

    Loads the model and processor once at container startup; each request
    is served through :meth:`__call__`.
    """

    def __init__(self, model_dir):
        """Load the pretrained model and processor from *model_dir*.

        Parameters
        ----------
        model_dir : str
            Path to the directory holding the model weights and
            processor/tokenizer files.
        """
        # bfloat16 halves memory vs fp32; device_map="auto" places the
        # model on available GPU(s), falling back to CPU.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_dir,
            trust_remote_code=True,
        )

    def __call__(self, data):
        """Run one image+prompt generation request.

        Parameters
        ----------
        data : dict
            Expected shape::

                {
                    "inputs": {"image": "<base64 image bytes>",
                               "text": "<prompt>"},
                    "parameters": {...}   # optional generate() overrides
                }

        Returns
        -------
        dict
            ``{"result": <generated text for the first (only) sample>}``.

        Raises
        ------
        ValueError
            If required fields are missing or the image cannot be decoded.
        """
        payload = data.get("inputs") or {}
        image_b64 = payload.get("image")
        prompt = payload.get("text")
        if image_b64 is None or prompt is None:
            # Fail with a clear 4xx-style message instead of a raw KeyError.
            raise ValueError(
                'Request must contain "inputs" with both "image" (base64) '
                'and "text" fields.'
            )

        try:
            # Normalize to RGB so palette/RGBA/grayscale uploads don't
            # break the processor's pixel pipeline.
            image = Image.open(
                io.BytesIO(base64.b64decode(image_b64))
            ).convert("RGB")
        except Exception as exc:
            raise ValueError("Could not decode base64 image payload.") from exc

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)

        # Optional per-request generation overrides (e.g. max_new_tokens,
        # temperature); defaults preserve the original 512-token budget.
        gen_kwargs = {"max_new_tokens": 512}
        gen_kwargs.update(data.get("parameters") or {})

        # inference_mode() skips autograd bookkeeping during decoding.
        with torch.inference_mode():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # Drop the echoed prompt tokens, keeping only newly generated ids.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, outputs)
        ]
        decoded = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
        )
        return {"result": decoded[0]}