# typhoon-ocr / handler.py
# Custom inference handler for scb10x/typhoon-ocr1.5-2b.
# (Provenance, from the Hugging Face repo page: uploaded by wealthcoders,
# commit a92f70e "Update handler.py", verified.)
from transformers import AutoModel, AutoTokenizer, AutoModelForImageTextToText, AutoProcessor
from typing import Dict, List, Any
import torch
import base64
import io
from io import BytesIO
from PIL import Image
import os
import tempfile
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for the Typhoon-OCR model.

    Accepts a base64-encoded image (in several request shapes) and returns
    the model's OCR transcription as markdown text.
    """

    def __init__(self, model_dir='scb10x/typhoon-ocr1.5-2b'):
        """Load the model and its processor.

        Args:
            model_dir: Model repo id or local path to load from.
        """
        model_path = model_dir
        # dtype="auto" uses the checkpoint's native dtype; device_map="auto"
        # places the model on whatever accelerator(s) are available.
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_path, dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(model_path)

    @staticmethod
    def _extract_base64(data):
        """Pull the base64 image string out of the accepted request shapes.

        Accepted shapes, in priority order:
          * raw string payload (the payload itself is the base64 data)
          * {"inputs": "<base64>"}
          * {"inputs": {"base64": "<base64>"}}
          * {"base64": "<base64>"}

        Returns:
            The base64 string, or None when no recognized field is present.
        """
        # Check the raw-string case first: `"inputs" in data` on a str does
        # substring matching and the subsequent data["inputs"] would raise
        # TypeError, so the string test must not come last.
        if isinstance(data, str):
            return data
        if "inputs" in data and isinstance(data["inputs"], str):
            return data["inputs"]
        if "inputs" in data and isinstance(data["inputs"], dict):
            return data["inputs"].get("base64")
        if "base64" in data:
            return data["base64"]
        return None

    @staticmethod
    def _decode_image(base64_string):
        """Decode a (possibly data-URL-prefixed) base64 string to a PIL image."""
        # Strip a "data:image/...;base64," prefix if present. Split at most
        # once so payloads that happen to contain further commas survive.
        if ',' in base64_string:
            base64_string = base64_string.split(',', 1)[1]
        image_data = base64.b64decode(base64_string)
        # Normalize to RGB so RGBA / palette PNGs are accepted downstream.
        return Image.open(io.BytesIO(image_data)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run OCR on the image carried in `data`.

        Args:
            data: Request payload; see _extract_base64 for accepted shapes.

        Returns:
            The markdown transcription (str) on success, an {"error": ...}
            dict when no base64 image was found, or the exception message
            (str) when processing fails.
        """
        try:
            base64_string = self._extract_base64(data)
            if not base64_string:
                # data may be a str here, which has no .keys(); guard it.
                available = (
                    str(data.keys()) if isinstance(data, dict)
                    else f"<{type(data).__name__} payload>"
                )
                return {"error": "No base64 string found in input data. Available keys: " + available}
            print("Found base64 string, length:", len(base64_string))
            image = self._decode_image(base64_string)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                        },
                        {
                            "type": "text",
                            "text": "Return content as markdown"
                        }
                    ],
                }
            ]
            # Tokenize the chat-formatted prompt (image + instruction).
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt"
            )
            inputs = inputs.to(self.model.device)
            generated_ids = self.model.generate(**inputs, max_new_tokens=10000)
            # Drop the echoed prompt tokens; keep only newly generated ones.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            print(output_text[0])
            return output_text[0]
        except Exception as e:
            # Endpoint boundary: surface the failure as text rather than a 500.
            print(f"Error processing image: {e}")
            return str(e)