Gabriel's picture
Create handler.py
e706c52 verified
raw
history blame
2.81 kB
from typing import Dict, Any
from transformers import QwenImageProcessor, QwenTokenizer, QwenForMultiModalConditionalGeneration
import torch
from PIL import Image
import io
import json
import base64
import requests
class EndpointHandler:
    """Inference Endpoint handler that generates text about a single image.

    NOTE(review): ``QwenImageProcessor``, ``QwenTokenizer`` and
    ``QwenForMultiModalConditionalGeneration`` are imported from
    ``transformers`` at the top of the file; these do not look like class
    names that ship with ``transformers`` (Qwen vision-language checkpoints
    usually load via ``Qwen2VLForConditionalGeneration`` / ``AutoProcessor``).
    Confirm the import resolves with the pinned transformers version.
    """

    # Chat-formatted prompt used when the caller supplies no text.
    # NOTE(review): Qwen vision-language chat templates normally embed image
    # placeholder tokens in the prompt; confirm this processor/model pair does
    # not need them.
    DEFAULT_PROMPT = "<|im_start|>user\nDescribe this image.\n<|im_end|><|im_start|>assistant\n"
    # Seconds to wait for an image URL fetch so a slow host cannot hang the worker.
    REQUEST_TIMEOUT = 30
    # Sampling settings forwarded to ``model.generate``.
    GENERATION_KWARGS = {
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.9,
        "temperature": 0.7,
    }

    def __init__(self, path=""):
        """Load the model, image processor and tokenizer from ``path``.

        Args:
            path: Model directory or hub id passed to every ``from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = QwenForMultiModalConditionalGeneration.from_pretrained(
            path,
            # Half precision on GPU, full precision on CPU.
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
        ).to(self.device)
        self.image_processor = QwenImageProcessor.from_pretrained(path)
        self.tokenizer = QwenTokenizer.from_pretrained(path)
        # NOTE(review): disabling the KV cache forces every decoding step to
        # recompute attention over the whole sequence, which is much slower.
        # Presumably a workaround for a cache-related failure -- confirm it is
        # still needed.
        self.model.generation_config.use_cache = False

    def _read_image_bytes(self, image_input: Any) -> bytes:
        """Return the encoded image bytes referenced by ``image_input``.

        Args:
            image_input: Raw image bytes, an ``http://``/``https://`` URL, a
                ``data:<mime>;base64,...`` URI, or a bare base64 string.

        Returns:
            The undecoded image file contents.

        Raises:
            ValueError: ``image_input`` is not bytes/str, or is malformed base64.
            requests.RequestException: The URL fetch failed or returned an
                HTTP error status.
        """
        if isinstance(image_input, (bytes, bytearray)):
            return bytes(image_input)
        if not isinstance(image_input, str):
            raise ValueError("Expected 'image' to be a base64 string, data URI or URL.")
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input, timeout=self.REQUEST_TIMEOUT)
            # Fail on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            return response.content
        if image_input.startswith("data:"):
            # Only the part after the "data:<mime>;base64," header is base64.
            image_input = image_input.partition(",")[2]
        return base64.b64decode(image_input)

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Args:
            data (Any): The input data, which can be:
                - Binary image data in the request body.
                - A dictionary with 'image' and 'text' keys:
                    - 'image': Raw bytes, base64-encoded image string,
                      base64 data URI, or http(s) image URL.
                    - 'text': The prompt, passed to the tokenizer verbatim
                      (it should already be chat-formatted). Defaults to
                      ``DEFAULT_PROMPT`` when missing or empty.
        Returns:
            Dict[str, Any]: ``{"generated_text": ...}`` holding only the newly
            generated text, or ``{"error": ...}`` when the input is invalid or
            the image cannot be loaded.
        """
        if isinstance(data, (bytes, bytearray)):
            image_input, text_input = data, ""
        elif isinstance(data, dict):
            image_input = data.get('image', None)
            text_input = data.get('text', '')
            if not image_input:
                return {"error": "No image provided."}
        else:
            return {"error": "Invalid input data. Expected binary image data or a dictionary with 'image' key."}

        try:
            image_bytes = self._read_image_bytes(image_input)
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except (ValueError, OSError, requests.RequestException) as exc:
            # Bad base64, unreadable image data and failed fetches are input
            # errors; report them the same way as the checks above.
            return {"error": f"Could not load image: {exc}"}

        prompt = text_input or self.DEFAULT_PROMPT
        image_inputs = self.image_processor(images=image, return_tensors="pt").to(self.device)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        generated_ids = self.model.generate(
            **image_inputs,
            input_ids=input_ids,
            **self.GENERATION_KWARGS,
        )
        # Assumes a decoder-only model, whose generate() output begins with the
        # prompt tokens; decode only the continuation so the prompt is not
        # echoed back. TODO confirm for this checkpoint.
        new_tokens = generated_ids[0, input_ids.shape[-1]:]
        output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return {"generated_text": output_text}