"""Inference Endpoints handler for scb10x/typhoon-ocr1.5-2b.

Accepts a base64-encoded page image and returns the transcribed content
as markdown.
"""
import base64
import io
from typing import Any, Dict

from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor


class EndpointHandler:
    def __init__(self, model_dir: str = "scb10x/typhoon-ocr1.5-2b"):
        # Load the OCR model and its processor once at startup;
        # device_map="auto" places the weights on a GPU when one is available.
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_dir, dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(model_dir)

    def __call__(self, data: Dict[str, Any]) -> Any:
        try:
            # Accept the base64-encoded image in any of the payload shapes
            # commonly sent to inference endpoints:
            #   "<base64>", {"inputs": "<base64>"},
            #   {"inputs": {"base64": "<base64>"}}, {"base64": "<base64>"}
            # The isinstance(data, str) check must come first: "inputs" in data
            # on a string is a substring test, and data.keys() would raise.
            base64_string = None
            if isinstance(data, str):
                base64_string = data
            elif "inputs" in data and isinstance(data["inputs"], str):
                base64_string = data["inputs"]
            elif "inputs" in data and isinstance(data["inputs"], dict):
                base64_string = data["inputs"].get("base64")
            elif "base64" in data:
                base64_string = data["base64"]

            if not base64_string:
                return {"error": f"No base64 string found in input data. Available keys: {list(data.keys())}"}
print("Found base64 string, length:", len(base64_string)) |
|
|
|
|
|
|
|
|
if ',' in base64_string: |
|
|
base64_string = base64_string.split(',')[1] |
|
|
|
|
|
|
|
|
image_data = base64.b64decode(base64_string) |
|
|
|
|
|
image = Image.open(io.BytesIO(image_data)) |
|
|
|
|
|

            # Build a single-turn chat message pairing the image with the
            # OCR instruction.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": "Return content as markdown"},
                    ],
                }
            ]

            # Tokenize the chat template and prepare the image features, then
            # move the batch to the same device as the model weights.
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.model.device)

            # A large max_new_tokens budget leaves room for full-page transcriptions.
            generated_ids = self.model.generate(**inputs, max_new_tokens=10000)

            # Slice off the prompt tokens so only newly generated text is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            print(output_text[0])

            return output_text[0]
        except Exception as e:
            # Mirror the error shape used above so callers can distinguish
            # failures from successful transcriptions.
            print(f"Error processing image: {e}")
            return {"error": str(e)}
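

# Minimal local smoke test: a sketch of how the handler can be exercised
# outside the endpoint. "sample_page.png" is a placeholder for any image on
# disk; this block is not part of the endpoint contract.
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("sample_page.png", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    # Send the payload in the same {"inputs": "<base64>"} shape the
    # endpoint receives, and print the transcribed markdown (or error dict).
    print(handler({"inputs": encoded}))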