import json
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration


class EndpointHandler:
    def __init__(self, model_name: str = "morthens/qwen2-vl-inference"):
        # Load the model and processor
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Extract the image path and chat messages from the request data
        image_path = data.get("image_path", "")
        messages = data.get("messages", [])

        # Load the image
        try:
            image = Image.open(image_path)
        except FileNotFoundError:
            return [{"error": "Image file not found."}]
        except Exception as e:
            return [{"error": str(e)}]

        # Build the text prompt from the chat messages
        text_prompt = self.create_text_prompt(messages)

        # Tokenize the prompt and preprocess the image for the model
        inputs = self.processor(
            text=[text_prompt],
            images=[image],
            padding=True,
            return_tensors="pt",
        )

        # Move inputs to the GPU if one is available
        inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")

        # Generate the model's response
        generated_ids = self.model.generate(**inputs, max_new_tokens=128)

        # Strip the prompt tokens from each generated sequence, then decode
        trimmed_ids = [
            output[len(prompt):]
            for prompt, output in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            trimmed_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # Strip Markdown fences from the decoded text and parse it as JSON
        cleaned_data = self.clean_output(output_text[0])
        try:
            json_data = json.loads(cleaned_data)
        except json.JSONDecodeError:
            return [{"error": "Failed to parse JSON output."}]

        return [json_data]

    def create_text_prompt(self, messages: List[Dict[str, Any]]) -> str:
        """Builds the generation prompt by applying the processor's chat template."""
        return self.processor.apply_chat_template(messages, add_generation_prompt=True)

    def clean_output(self, output: str) -> str:
        """Removes Markdown code fences so the model output can be parsed as JSON."""
        return output.replace("```json\n", "").replace("```", "").strip()
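

# --- Example usage (illustrative sketch only) ---
# A minimal local smoke test, assuming a Qwen2-VL-style chat message layout and a
# hypothetical local image "example.jpg". The exact payload schema ultimately depends
# on how the deployed endpoint is called, so treat this as an assumption, not a contract.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "image_path": "example.jpg",  # hypothetical path, replace with a real image
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Describe this image and answer in JSON."},
                ],
            }
        ],
    }
    # Returns a one-element list containing either the parsed JSON or an error dict
    print(handler(payload))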