import base64
import io
import logging
from typing import Any, Dict

import requests
import torch
from PIL import Image
from moviepy.editor import VideoFileClip
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info

class EndpointHandler():
    """
    Handler class for the Qwen2-VL-7B-Instruct model on Hugging Face Inference Endpoints.

    This handler processes text, image, and video inputs, leveraging the Qwen2-VL model
    for multimodal understanding and generation.
    """

    def __init__(self, path=""):
        """
        Initializes the handler and loads the Qwen2-VL model.

        Args:
            path (str, optional): The path to the Qwen2-VL model directory. Defaults to "".
        """
        self.model_dir = path
        # Load the Qwen2-VL model and its processor from the endpoint's model directory.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_dir, torch_dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(self.model_dir)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes the input data and returns the Qwen2-VL model's output.

        Args:
            data (Dict[str, Any]): A dictionary containing the input data.
                - "inputs" (str): The input text, including image/video references.
                - "max_new_tokens" (int, optional): Max tokens to generate (default: 128).

        Returns:
            Dict[str, Any]: A dictionary containing the generated text.
        """
        inputs = data.get("inputs")
        max_new_tokens = data.get("max_new_tokens", 128)

        # Construct the messages list from the input string
        messages = [{"role": "user", "content": self._parse_input(inputs)}]

        # Prepare for inference (using qwen_vl_utils)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")

        # Inference: generate, then strip the prompt tokens from each output sequence.
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return {"generated_text": output_text}

    def _parse_input(self, input_string):
        """
        Parses the input string to identify image/video references and text.

        Args:
            input_string (str): The input string containing text, image, and video references.

        Returns:
            list: A list of dictionaries representing the parsed content.
        """
        content = []
        # Media references are wrapped in <image> markers: after splitting, even-indexed
        # parts are plain text and odd-indexed parts are image/video references.
        parts = input_string.split("<image>")
        for i, part in enumerate(parts):
            if i % 2 == 0:  # Text part
                if part.strip():
                    content.append({"type": "text", "text": part.strip()})
            else:  # Image/video part
                ref = part.strip()
                if ref.lower().startswith("video:"):
                    # Strip the "video:" prefix (checked case-insensitively) to get the path/URL.
                    video_path = ref[len("video:"):].strip()
                    video_frames = self._extract_video_frames(video_path)
                    if video_frames:
                        content.append({"type": "video", "video": video_frames, "fps": 1})
                else:
                    image = self._load_image(ref)
                    if image:
                        content.append({"type": "image", "image": image})
        return content

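    # Example of the parsing above, assuming a hypothetical input string
    # "Compare these. <image>https://example.com/a.jpg<image> and <image>video: /data/clip.mp4<image>":
    # _parse_input would return roughly
    #   [{"type": "text", "text": "Compare these."},
    #    {"type": "image", "image": <PIL.Image>},
    #    {"type": "text", "text": "and"},
    #    {"type": "video", "video": [<PIL.Image>, ...], "fps": 1}]
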
    def _load_image(self, image_data):
        """
        Loads an image from a URL or base64 encoded string.

        Args:
            image_data (str): The image data, either a URL or a base64 encoded string.

        Returns:
            PIL.Image.Image or None: The loaded image, or None if loading fails.
        """
        try:
            if image_data.startswith("http"):
                response = requests.get(image_data, stream=True)
                response.raise_for_status()  # Check for HTTP errors
                return Image.open(response.raw)
            elif image_data.startswith("data:image"):
                base64_data = image_data.split(",")[1]
                image_bytes = base64.b64decode(base64_data)
                return Image.open(io.BytesIO(image_bytes))
            else:
                logging.error(f"Unrecognized image reference (expected an http(s) URL or data:image URI): {image_data[:50]}")
        except requests.exceptions.RequestException as e:
            logging.error(f"HTTP error occurred while loading image: {e}")
        except IOError as e:
            logging.error(f"Error opening image: {e}")
        return None

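    # Accepted image references (illustrative values only):
    #   "https://example.com/photo.jpg"          -> fetched over HTTP(S)
    #   "data:image/png;base64,iVBORw0KGgo..."   -> decoded from a base64 data URI
    # Anything else is logged as an error and the image is skipped by _parse_input.
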
    def _extract_video_frames(self, video_path, fps=1):
        """
        Extracts frames from a video at the specified FPS using MoviePy.

        Args:
            video_path (str): The path or URL of the video file.
            fps (int, optional): The desired frames per second. Defaults to 1.

        Returns:
            list or None: A list of PIL Images representing the extracted frames,
            or None if extraction fails.
        """
        try:
            with VideoFileClip(video_path) as video:
                return [
                    Image.fromarray(frame.astype("uint8"), "RGB")
                    for frame in video.iter_frames(fps=fps)
                ]
        except Exception as e:
            logging.error(f"Error extracting video frames: {e}")
            return None

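    # Illustrative call (hypothetical path): _extract_video_frames("/data/clip.mp4", fps=1)
    # yields roughly one PIL.Image per second of video, which _parse_input packs as
    # {"type": "video", "video": [...frames...], "fps": 1} for process_vision_info.
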
# Module-level logging configuration (applied when this handler module is imported).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
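

# Optional local smoke test; not used by the Inference Endpoints runtime, which only
# imports EndpointHandler. This is a minimal sketch assuming a locally available copy
# of the model weights; the model path, image URL, and prompt below are placeholders.
if __name__ == "__main__":
    handler = EndpointHandler(path="Qwen/Qwen2-VL-7B-Instruct")  # placeholder path / repo id
    sample_payload = {
        "inputs": "What is shown here? <image>https://example.com/photo.jpg<image>",  # placeholder URL
        "max_new_tokens": 64,
    }
    print(handler(sample_payload))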