from typing import Dict, Any
import torch
import base64
import io
from PIL import Image
import logging
import requests
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from moviepy.editor import VideoFileClip

class EndpointHandler():
    """
    Handler class for the Qwen2-VL-7B-Instruct model on Hugging Face Inference Endpoints.
    This handler processes text, image, and video inputs, leveraging the Qwen2-VL model
    for multimodal understanding and generation.
    """

    def __init__(self, path=""):
        """
        Initializes the handler and loads the Qwen2-VL model.
        Args:
            path (str, optional): The path to the Qwen2-VL model directory. Defaults to "".
        """
        self.model_dir = path

        # Load the Qwen2-VL model
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_dir, torch_dtype="auto", device_map="auto"
        )
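        # Note: torch_dtype="auto" lets the checkpoint config pick the dtype. On
        # memory-constrained GPUs an explicit torch_dtype=torch.bfloat16 is a
        # possible alternative (an optional tweak, not required by this handler).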
        self.processor = AutoProcessor.from_pretrained(self.model_dir)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes the input data and returns the Qwen2-VL model's output.
        Args:
            data (Dict[str, Any]): A dictionary containing the input data.
                - "inputs" (str): The input text, including image/video references.
                - "max_new_tokens" (int, optional): Max tokens to generate (default: 128).
        Returns:
            Dict[str, Any]: A dictionary containing the generated text.
        """
        inputs = data.get("inputs")
        max_new_tokens = data.get("max_new_tokens", 128)

        # Construct the messages list from the input string
        messages = [{"role": "user", "content": self._parse_input(inputs)}]

        # Prepare for inference (using qwen_vl_utils)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)

        model_inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        model_inputs = model_inputs.to("cuda" if torch.cuda.is_available() else "cpu")

        # Inference
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        return {"generated_text": output_text}

    def _parse_input(self, input_string):
        """
        Parses the input string to identify image/video references and text.
        Args:
            input_string (str): The input string containing text, image, and video references.
        Returns:
            list: A list of dictionaries representing the parsed content.
        """
        content = []
        parts = input_string.split("<image>")
        for i, part in enumerate(parts):
            if i % 2 == 0:  # Text part
                if part.strip():
                    content.append({"type": "text", "text": part.strip()})
            else:  # Image/video part
                if part.lower().startswith("video:"):
                    # Strip the prefix by length so "Video:" and "VIDEO:" also work.
                    video_path = part[len("video:"):].strip()
                    video_frames = self._extract_video_frames(video_path)
                    if video_frames:
                        content.append({"type": "video", "video": video_frames, "fps": 1})
                else:
                    image = self._load_image(part.strip())
                    if image:
                        content.append({"type": "image", "image": image})
        return content
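
    # Illustration of the expected input format (hypothetical example, not taken
    # from the model card): the string
    #   "What is shown here? <image>https://example.com/dog.jpg<image> Answer briefly."
    # is split on the "<image>" delimiter, and the odd-numbered parts are treated
    # as media references, yielding roughly:
    #   [
    #       {"type": "text", "text": "What is shown here?"},
    #       {"type": "image", "image": <PIL.Image>},
    #       {"type": "text", "text": "Answer briefly."},
    #   ]
    # A video is referenced the same way, with the part prefixed by "video:".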

    def _load_image(self, image_data):
        """
        Loads an image from a URL or base64 encoded string.
        Args:
            image_data (str): The image data, either a URL or a base64 encoded string.
        Returns:
            PIL.Image.Image or None: The loaded image, or None if loading fails.
        """
        try:
            if image_data.startswith("http"):
                response = requests.get(image_data, timeout=30)
                response.raise_for_status()  # Check for HTTP errors
                # Decode from the fully downloaded bytes so content-encoding
                # (e.g. gzip) cannot corrupt the stream handed to PIL.
                return Image.open(io.BytesIO(response.content))
            elif image_data.startswith("data:image"):
                base64_data = image_data.split(",")[1]
                image_bytes = base64.b64decode(base64_data)
                return Image.open(io.BytesIO(image_bytes))
            else:
                logging.error("Unrecognized image reference; expected an http(s) URL or a data:image URI.")
        except requests.exceptions.RequestException as e:
            logging.error(f"HTTP error occurred while loading image: {e}")
        except IOError as e:
            logging.error(f"Error opening image: {e}")
        return None

    def _extract_video_frames(self, video_path, fps=1):
        """
        Extracts frames from a video at the specified FPS using MoviePy.
        Args:
            video_path (str): The path or URL of the video file.
            fps (int, optional): The desired frames per second. Defaults to 1.
        Returns:
            list or None: A list of PIL Images representing the extracted frames,
                          or None if extraction fails.
        """
        try:
            with VideoFileClip(video_path) as video:
                return [Image.fromarray(frame.astype('uint8'), 'RGB') for frame in video.iter_frames(fps=fps)]
        except Exception as e:
            logging.error(f"Error extracting video frames: {e}")
            return None

# Configure logging for the handler module
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
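
# Minimal local smoke test: a sketch for manual verification, not part of the
# Inference Endpoints contract. It assumes the Qwen2-VL-7B-Instruct weights are
# resolvable at MODEL_PATH (cached or downloadable) and that enough GPU memory
# is available; the image URL is a placeholder.
if __name__ == "__main__":
    MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct"  # assumption: resolvable model path
    handler = EndpointHandler(path=MODEL_PATH)
    sample_request = {
        "inputs": "Describe this image. <image>https://example.com/sample.jpg<image>",
        "max_new_tokens": 64,
    }
    print(handler(sample_request))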