| from typing import Union, Optional, List, Dict, Tuple, Callable |
| from transformers.processing_utils import (ProcessorMixin, |
| VideosKwargs, |
| AudioKwargs, |
| ImagesKwargs, |
| TextKwargs, |
| ProcessingKwargs, |
| Unpack) |
| import numpy as np |
| import decord |
| import torch |
| import PIL |
| from transformers.audio_utils import load_audio |
| from transformers.image_utils import load_image, load_video |
| from transformers import AutoImageProcessor, AutoFeatureExtractor, AutoTokenizer |
|
|
|
|
def load_audio_str(audio_path_or_url: str, sampling_rate: int = 16000) -> np.ndarray:
    """Load an audio file path or URL and return its waveform resampled to ``sampling_rate`` Hz."""
    return load_audio(audio_path_or_url, sampling_rate=sampling_rate)
|
|
|
|
def load_video_str(video_path_or_url: str, num_frames: int = 4, fps: Optional[int] = None) -> Tuple[np.ndarray, dict]:
    """Load a video from a path or URL via ``transformers.load_video`` (decord backend).

    Args:
        video_path_or_url: local path or URL of the video.
        num_frames: number of frames to sample (ignored by some backends when ``fps`` is set).
        fps: optional sampling rate in frames per second; mutually exclusive with ``num_frames``.

    Returns:
        A ``(frames, metadata)`` pair as produced by ``load_video`` — callers
        (see ``QualityvProcessor.__call__``) unpack each item as
        ``video, video_meta``, so the pair is returned as-is.
    """
    return load_video(video_path_or_url, num_frames=num_frames, fps=fps,
                      backend="decord")
|
|
|
|
def load_image_str(image_path_or_url: str) -> "PIL.Image.Image":
    """Load a single image from a path or URL via ``transformers.load_image``.

    Returns a single ``PIL.Image.Image`` (the previous ``List[np.ndarray]``
    annotation did not match what ``load_image`` produces).
    """
    return load_image(image_path_or_url)
|
|
|
|
# Accepted image inputs: in-memory images (PIL / numpy / torch, single or in a
# list) or string paths/URLs that will be resolved through ``load_image_str``.
ImageInput = Union[

    "PIL.Image.Image", np.ndarray, "torch.Tensor", list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"],

    str, list[str]
]




# Accepted video inputs: frame sequences (lists of PIL images, arrays or
# tensors — possibly batched as lists of lists) or string paths/URLs that will
# be resolved through ``load_video_str``.
VideoInput = Union[

    list["PIL.Image.Image"], "np.ndarray", "torch.Tensor", list["np.ndarray"],
    list["torch.Tensor"], list[list["PIL.Image.Image"]], list[list["np.ndarray"]],
    list[list["torch.Tensor"]],

    str, list[str], list[list[str]]
]




# Accepted audio inputs: raw waveforms (numpy / torch, single or in a
# list/tuple) or string paths/URLs that will be resolved through
# ``load_audio_str``.
AudioInput = Union[

    np.ndarray, "torch.Tensor", List[np.ndarray], Tuple[np.ndarray], List["torch.Tensor"], Tuple["torch.Tensor"],

    str, list[str]
]
|
|
|
|
class QualityvImageKwargs(ImagesKwargs):
    # Number of placeholder <|image_pad|> tokens inserted per image in the prompt.
    tokens_per_image: int = 197
|
|
|
|
class QualityvVideoKwargs(VideosKwargs):
    # Frames sampled per video when loading from a path/URL (None = backend default).
    num_frames: Union[int, None] = 4
    # Frame sampling rate; alternative to num_frames when loading from a path/URL.
    fps: Union[int, None] = None
    # Number of placeholder <|video_pad|> tokens inserted per sampled frame.
    tokens_per_frame: int = 197
|
|
|
|
class QualityvAudioKwargs(AudioKwargs):
    # Target sampling rate (Hz) for loading audio and for the feature extractor.
    sampling_rate: Union[int, None] = 16000
    # Number of placeholder <|audio_pad|> tokens inserted per audio clip.
    tokens_per_audio: int = 1500
| |
| |
class QualityvProcessingKwargs(ProcessingKwargs):
    # Per-modality keyword groups accepted by QualityvProcessor.__call__
    # (read via kwargs.get("images_kwargs"), etc.).
    images_kwargs: QualityvImageKwargs
    videos_kwargs: QualityvVideoKwargs
    audio_kwargs: QualityvAudioKwargs
    text_kwargs: TextKwargs
| |
|
|
class QualityvProcessor(ProcessorMixin):
    """Multimodal processor that turns text or chat ``messages`` plus optional
    images, videos and audio into model-ready inputs.

    It bundles three sub-processors (wired by :class:`ProcessorMixin` through
    the ``attributes`` list): an image processor, an audio feature extractor
    and a tokenizer.  Media given as string paths/URLs is loaded lazily, and
    each media item is expanded in the prompt into a fixed number of
    placeholder tokens (``tokens_per_image`` / ``tokens_per_frame`` /
    ``tokens_per_audio``) so token positions line up with embedded features.
    """

    # ProcessorMixin maps these names onto the positional arguments passed to
    # super().__init__ at the end of __init__ below.
    attributes = ["image_processor",
                  "audio_processor",
                  "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    # Qwen-style Jinja chat template.  Inserts exactly one
    # <|image_pad|> / <|video_pad|> / <|audio_pad|> placeholder per media item;
    # __call__ later expands each placeholder to the configured token count.
    chat_template = """{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set audio_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endif %}<|im_start|>{{ message['role'] }}
{% if message['content'] is string %}{{ message['content'] }}<|im_end|>
{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'audio' or 'audio' in content %}{% set audio_count.value = audio_count.value + 1 %}{% if add_vision_id %}Audio {{ audio_count.value }}: {% endif %}<|vision_start|><|audio_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>
{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
{% endif %}"""

    def __init__(self, tokenizer=None,
                 image_processor=None,
                 audio_processor=None,
                 chat_template=None,
                 image_token="<|image_pad|>",
                 video_token="<|video_pad|>",
                 audio_token="<|audio_pad|>",
                 label_start_text="<|im_start|>assistant\n",
                 label_end_text="<|im_end|>\n",
                 **kwargs):
        """Store placeholder tokens/ids and register the sub-processors.

        Args:
            tokenizer: text tokenizer (AutoTokenizer).
            image_processor: image preprocessor (AutoImageProcessor).
            audio_processor: audio feature extractor (AutoFeatureExtractor).
            chat_template: optional Jinja template overriding the class default.
            image_token / video_token / audio_token: placeholder strings; a
                token attribute defined on the tokenizer takes precedence.
            label_start_text / label_end_text: markers delimiting assistant
                responses when building training labels (see ``get_labels``).
        """
        # Tokenizer-defined special tokens win over the constructor defaults.
        self.image_token = image_token if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        self.video_token = video_token if not hasattr(tokenizer, "video_token") else tokenizer.video_token
        self.audio_token = audio_token if not hasattr(tokenizer, "audio_token") else tokenizer.audio_token
        self.label_start_text = label_start_text
        self.label_end_text = label_end_text
        # Resolve placeholder ids, falling back to a vocabulary lookup when the
        # tokenizer does not expose a dedicated *_token_id attribute.
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.image_token)
        )
        self.video_token_id = (
            tokenizer.video_token_id
            if getattr(tokenizer, "video_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.video_token)
        )
        self.audio_token_id = (
            tokenizer.audio_token_id
            if getattr(tokenizer, "audio_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.audio_token)
        )
        if chat_template is None:
            chat_template = self.chat_template
        super().__init__(image_processor, audio_processor, tokenizer,
                         chat_template=chat_template)

    def __call__(self,
                 text: Union[str, List[str], None] = None,
                 messages: Union[List[Dict], None] = None,
                 images: Union[ImageInput, None] = None,
                 videos: Union[VideoInput, None] = None,
                 audio: Union[AudioInput, None] = None,
                 do_train: bool = False,
                 add_generation_prompt: bool = False,
                 **kwargs: Unpack[QualityvProcessingKwargs]
                 ):
        """Build model inputs from ``text`` or chat ``messages`` plus media.

        Args:
            text: pre-rendered prompt string (ignored when ``messages`` given).
            messages: chat messages, e.g.
                ``[{"role": "user", "content": [{"type": "text", "text": ...},
                {"type": "image", "image": ...}, ...]}]``; rendered through the
                chat template.
            images / videos / audio: media overriding anything embedded in
                ``messages``; string entries are loaded from path/URL.
            do_train: when True, also build ``labels`` (see ``get_labels``).
            add_generation_prompt: forwarded to the chat template.
            **kwargs: per-modality kwargs groups (QualityvProcessingKwargs).

        Returns:
            Dict with ``input_ids``, ``pixel_values``, ``pixel_values_videos``,
            ``audio_values`` (each ``None`` when the modality is absent) and
            ``labels`` (``None`` unless ``do_train``).

        Raises:
            ValueError: when neither ``text`` nor ``messages`` is provided.
        """
        input_ids = []
        pixel_values = []
        pixel_values_videos = []
        audio_values = []
        labels = None

        if not text and not messages:
            raise ValueError("At least one of text or messages must be provided.")
        if messages:
            text = self.apply_chat_template(messages, add_generation_prompt=add_generation_prompt,
                                            tokenize=False)
            if isinstance(text, list):
                text = text[0]

        # --- images -----------------------------------------------------
        image_list = self.fill_modal_list(self.image_token, "image", messages, images, text)
        # load_image_str takes no extra kwargs: forwarding the whole
        # images_kwargs dict (e.g. tokens_per_image) would raise TypeError
        # for string inputs, so nothing is forwarded here.
        image_list = self.process_str_in_modal_list(image_list, "image")

        if image_list and self.image_token in text:
            tokens_per_image = kwargs.get("images_kwargs", {}).get("tokens_per_image", 197)
            # Expand every image placeholder to tokens_per_image copies.
            text = text.replace(self.image_token, tokens_per_image * self.image_token)
            pixel_values = self.image_processor(images=image_list, return_tensors="pt")["pixel_values"]

        # --- videos -----------------------------------------------------
        videos_kwargs = kwargs.get("videos_kwargs", {})
        video_list = self.fill_modal_list(self.video_token, "video", messages, videos, text)
        # Forward only the kwargs load_video_str understands; the rest
        # (e.g. tokens_per_frame) would raise TypeError in the loader.
        video_loader_kwargs = {k: videos_kwargs[k] for k in ("num_frames", "fps") if k in videos_kwargs}
        video_list = self.process_str_in_modal_list(video_list, "video", **video_loader_kwargs)

        if video_list and self.video_token in text:
            tokens_per_frame = videos_kwargs.get("tokens_per_frame", 197)
            video_frame_list = []
            # Each item is a (frames, metadata) pair as returned by load_video;
            # videos passed in-memory must follow the same shape — TODO confirm.
            for video, video_meta in video_list:
                num_frames = video.shape[0]
                # Expand this video's placeholder to num_frames * tokens_per_frame copies.
                replace_text = num_frames * tokens_per_frame * self.video_token
                text = text.replace(self.video_token, replace_text, 1)
                for frame in video:
                    video_frame_list.append(frame)
            pixel_values_videos = self.image_processor(images=video_frame_list, return_tensors="pt")["pixel_values"]

        # --- audio ------------------------------------------------------
        audio_kwargs = kwargs.get("audio_kwargs", {})
        audio_list = self.fill_modal_list(self.audio_token, "audio", messages, audio, text)
        # Forward only sampling_rate; e.g. tokens_per_audio would raise
        # TypeError in load_audio_str.
        audio_loader_kwargs = {k: audio_kwargs[k] for k in ("sampling_rate",) if k in audio_kwargs}
        audio_list = self.process_str_in_modal_list(audio_list, "audio", **audio_loader_kwargs)

        if audio_list and self.audio_token in text:
            sampling_rate = audio_kwargs.get("sampling_rate", 16000)
            tokens_per_audio = audio_kwargs.get("tokens_per_audio", 1500)
            # One placeholder expansion per clip (avoid shadowing the `audio`
            # parameter with the loop variable).
            for _ in audio_list:
                replace_text = tokens_per_audio * self.audio_token
                text = text.replace(self.audio_token, replace_text, 1)
            audio_values = self.audio_processor(audio_list, return_tensors="pt", sampling_rate=sampling_rate)["input_features"]

        # --- text / labels ----------------------------------------------
        input_ids = self.tokenizer(text).input_ids
        if do_train:
            labels = self.get_labels(input_ids)
            labels = torch.tensor(labels, dtype=torch.long)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        return {
            "input_ids": input_ids,
            "pixel_values": pixel_values if len(pixel_values) > 0 else None,
            "pixel_values_videos": pixel_values_videos if len(pixel_values_videos) > 0 else None,
            "audio_values": audio_values if len(audio_values) > 0 else None,
            "labels": labels
        }

    def fill_modal_list(self, modal_token: str, model_type: str, messages: List[Dict], modal_values: Union[AudioInput, VideoInput, ImageInput, None], text: str) -> List[Union[AudioInput, VideoInput, ImageInput]]:
        """Collect the media items for one modality, in prompt order.

        Explicit ``modal_values`` win; otherwise media embedded in user
        ``messages`` content entries of type ``model_type`` is used.  Returns
        an empty list when ``text`` contains no ``modal_token`` placeholder.
        (Parameter is named ``model_type`` for backward compatibility; it is
        the modality name: "image", "video" or "audio".)
        """
        modal_list = []
        if modal_token in text:
            if not modal_values and messages:
                for msg in messages:
                    if msg.get("role") == "user":
                        for content in msg.get("content", []):
                            if content.get('type') == model_type:
                                modal_list.append(content.get(model_type))
            elif modal_values:
                # A single string is one item, not an iterable of characters.
                if isinstance(modal_values, str):
                    modal_list = [modal_values]
                else:
                    modal_list = modal_values
        return modal_list

    def process_str_in_modal_list(self, modal_list: list, modal_type: str, **modal_kwargs: dict):
        """Replace string path/URL entries with loaded media; pass everything
        else through unchanged.  ``modal_kwargs`` go to the modality loader."""
        new_modal_list = []
        if modal_list:
            for modal_value in modal_list:
                if isinstance(modal_value, str):
                    new_modal_value = self.load_modal_str(modal_value, modal_type, **modal_kwargs)
                    new_modal_list.append(new_modal_value)
                else:
                    new_modal_list.append(modal_value)
        return new_modal_list

    def load_modal_str(self, model_path_or_url: str, modal_type: str, **modal_kwargs):
        """Dispatch a path/URL to the loader for ``modal_type``.

        Raises:
            ValueError: for an unknown ``modal_type``.
        """
        if modal_type == "image":
            load_func = load_image_str
        elif modal_type == "video":
            load_func = load_video_str
        elif modal_type == "audio":
            load_func = load_audio_str
        else:
            raise ValueError(f"Invalid modal type: {modal_type}")
        return load_func(model_path_or_url, **modal_kwargs)

    def get_labels(self, input_ids: List[int]) -> List[int]:
        """Build causal-LM labels: assistant response spans keep their token
        ids; everything else (prompts, placeholders, padding) becomes -100."""
        label_start_token_ids = self.tokenizer(self.label_start_text, add_special_tokens=False)["input_ids"]
        label_end_token_ids = self.tokenizer(self.label_end_text, add_special_tokens=False)["input_ids"]

        # Mask everything by default.
        labels = [-100] * len(input_ids)

        i = 0
        while i < len(input_ids):
            # Found a start-of-response marker?
            if input_ids[i : i + len(label_start_token_ids)] == label_start_token_ids:
                start_response = i + len(label_start_token_ids)
                # Scan forward for the matching end marker.
                j = start_response
                found_end = False
                while j < len(input_ids):
                    if input_ids[j : j + len(label_end_token_ids)] == label_end_token_ids:
                        end_response = j + len(label_end_token_ids)
                        found_end = True
                        break
                    j += 1

                if found_end:
                    # Supervise the response tokens (end marker included).
                    labels[start_response:end_response] = input_ids[start_response:end_response]
                    i = end_response
                    continue
                else:
                    # Unterminated response: leave the rest masked.
                    break
            else:
                i += 1
        # Never supervise padding tokens.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is not None:
            for idx in range(len(labels)):
                if labels[idx] == pad_token_id:
                    labels[idx] = -100
        return labels

    def decode(self, *args, **kwargs):
        """Forward to ``tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """Forward to ``tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)
| |
| |
| |
|
|