# handler.py for the Qwen2-VL-7B-Instruct Hugging Face Inference Endpoint
from typing import Dict, Any
import torch
import os
import base64
import io
from PIL import Image
import logging
import requests
import traceback # For formatting exception tracebacks
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from moviepy.editor import VideoFileClip


class EndpointHandler:
"""
Handler class for the Qwen2-VL-7B-Instruct model on Hugging Face Inference Endpoints.
This handler processes text, image, and video inputs, leveraging the Qwen2-VL model
for multimodal understanding and generation.
"""
def __init__(self, path=""):
"""
Initializes the handler and loads the Qwen2-VL model.
Args:
path (str, optional): The path to the Qwen2-VL model directory. Defaults to "".
"""
self.model_dir = path
        # Load the Qwen2-VL model. torch_dtype="auto" keeps the checkpoint's native dtype and
        # device_map="auto" (which relies on the accelerate package) places the weights on the
        # available device(s).
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
self.model_dir, torch_dtype="auto", device_map="auto"
)
self.processor = AutoProcessor.from_pretrained(self.model_dir)
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Processes the input data and returns the Qwen2-VL model's output.
Args:
data (Dict[str, Any]): A dictionary containing the input data.
- "inputs" (str): The input text, including image/video references.
- "max_new_tokens" (int, optional): Max tokens to generate (default: 128).
Returns:
Dict[str, Any]: A dictionary containing the generated text.
"""
        input_text = data.get("inputs")
        if input_text is None:
            raise ValueError('The request payload must include an "inputs" field.')
        max_new_tokens = data.get("max_new_tokens", 128)
        # Construct the messages list from the input string
        messages = [{"role": "user", "content": self._parse_input(input_text)}]
# Prepare for inference (using qwen_vl_utils)
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
        # Move the prepared tensors to the model's device (GPU if available, otherwise CPU)
        inputs = inputs.to(self.model.device)
# Inference
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Strip the prompt tokens so that only the newly generated tokens are decoded
        generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
return {"generated_text": output_text}
def _parse_input(self, input_string):
"""
        Parses the input string to identify image/video references and text.

        Image and video references must be wrapped in a pair of "<image>" tags; a wrapped
        reference that begins with "video:" is treated as a video path or URL rather than
        an image.

Args:
input_string (str): The input string containing text, image, and video references.
Returns:
list: A list of dictionaries representing the parsed content.
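
        Example (hypothetical path; assumes frame extraction succeeds):
            "Summarize this clip: <image>video:/data/clip.mp4<image>" parses to
            [{"type": "text", "text": "Summarize this clip:"},
             {"type": "video", "video": [<PIL frames>], "fps": 1},
             {"type": "text", "text": ""}]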
"""
content = []
parts = input_string.split("<image>")
for i, part in enumerate(parts):
if i % 2 == 0: # Text part
content.append({"type": "text", "text": part.strip()})
else: # Image/video part
                if part.lower().startswith("video:"):
                    # Strip the "video:" prefix (case-insensitive check) to get the path or URL
                    video_path = part[len("video:"):].strip()
video_frames = self._extract_video_frames(video_path)
if video_frames:
content.append({"type": "video", "video": video_frames, "fps": 1})
else:
image = self._load_image(part.strip())
if image:
content.append({"type": "image", "image": image})
return content
def _load_image(self, image_data):
"""
Loads an image from a URL or base64 encoded string.
Args:
            image_data (str): The image reference, either an http(s) URL or a
                "data:image/...;base64,..." URI.
Returns:
PIL.Image.Image or None: The loaded image, or None if loading fails.
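
        Example references (placeholders, not real assets):
            "https://example.com/cat.jpg"
            "data:image/png;base64,iVBORw0KGgo..."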
"""
try:
if image_data.startswith("http"):
response = requests.get(image_data, stream=True)
response.raise_for_status() # Check for HTTP errors
return Image.open(response.raw)
elif image_data.startswith("data:image"):
base64_data = image_data.split(",")[1]
image_bytes = base64.b64decode(base64_data)
return Image.open(io.BytesIO(image_bytes))
except requests.exceptions.RequestException as e:
logging.error(f"HTTP error occurred while loading image: {e}")
except IOError as e:
logging.error(f"Error opening image: {e}")
return None
def _extract_video_frames(self, video_path, fps=1):
"""
Extracts frames from a video at the specified FPS using MoviePy.
Args:
video_path (str): The path or URL of the video file.
fps (int, optional): The desired frames per second. Defaults to 1.
Returns:
list or None: A list of PIL Images representing the extracted frames,
or None if extraction fails.
"""
try:
with VideoFileClip(video_path) as video:
return [Image.fromarray(frame.astype('uint8'), 'RGB') for frame in video.iter_frames(fps=fps)]
        except Exception as e:
            # Log the full traceback so the endpoint logs show where frame extraction failed
            logging.error(f"Error extracting video frames: {e}\n{traceback.format_exc()}")
return None

# Module-level logging configuration (applied when this handler module is imported)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
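

# Minimal local smoke test: a sketch for manual debugging, not part of the Inference
# Endpoints contract. It assumes the weights are reachable (MODEL_PATH is a hypothetical
# environment variable; the fallback Hub ID triggers a download), and the image URL is a
# placeholder.
if __name__ == "__main__":
    handler = EndpointHandler(path=os.environ.get("MODEL_PATH", "Qwen/Qwen2-VL-7B-Instruct"))
    result = handler(
        {
            "inputs": "What is in this image? <image>https://example.com/cat.jpg<image>",
            "max_new_tokens": 64,
        }
    )
    print(result["generated_text"])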