# qualityv-0606 / processing_qualityv.py
from typing import Union, Optional, List, Dict, Tuple
from transformers.processing_utils import (ProcessorMixin,
VideosKwargs,
AudioKwargs,
ImagesKwargs,
TextKwargs,
ProcessingKwargs,
Unpack)
import numpy as np
import inspect
import decord  # noqa: F401  (needed at runtime for load_video's "decord" backend)
import torch
import PIL
from transformers.audio_utils import load_audio
from transformers.image_utils import load_image, load_video
from transformers import AutoImageProcessor, AutoFeatureExtractor, AutoTokenizer
def load_audio_str(audio_path_or_url: str, sampling_rate: int = 16000) -> np.ndarray:
    audio = load_audio(audio_path_or_url, sampling_rate=sampling_rate)
    return audio
def load_video_str(video_path_or_url: str, num_frames: int = 4, fps: Optional[int] = None) -> Tuple[np.ndarray, dict]:
    # load_video returns the decoded frames together with the video metadata.
    video = load_video(video_path_or_url, num_frames=num_frames, fps=fps,
                       backend="decord")
    return video
def load_image_str(image_path_or_url: str) -> "PIL.Image.Image":
    image = load_image(image_path_or_url)
    return image
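# Example usage (sketch; the paths below are hypothetical):
#   image = load_image_str("https://example.com/cat.png")
#   frames, video_meta = load_video_str("clip.mp4", num_frames=4)
#   waveform = load_audio_str("speech.wav", sampling_rate=16000)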
ImageInput = Union[
    # same as transformers.image_utils.ImageInput
    "PIL.Image.Image", np.ndarray, torch.Tensor,
    list["PIL.Image.Image"], list[np.ndarray], list[torch.Tensor],
    # image URLs or image paths
    str, list[str],
]
VideoInput = Union[
    # same as transformers.image_utils.VideoInput
    list["PIL.Image.Image"], np.ndarray, torch.Tensor,
    list[np.ndarray], list[torch.Tensor],
    list[list["PIL.Image.Image"]], list[list[np.ndarray]], list[list[torch.Tensor]],
    # video URLs or video paths
    str, list[str], list[list[str]],
]
AudioInput = Union[
    # same as transformers.audio_utils.AudioInput
    np.ndarray, torch.Tensor, list[np.ndarray], tuple[np.ndarray],
    list[torch.Tensor], tuple[torch.Tensor],
    # audio URLs or audio paths
    str, list[str],
]
class QualityvImageKwargs(ImagesKwargs):
tokens_per_image: int = 197
class QualityvVideoKwargs(VideosKwargs):
num_frames: Union[int, None] = 4
fps: Union[int, None] = None
tokens_per_frame: int = 197
class QualityvAudioKwargs(AudioKwargs):
sampling_rate: Union[int, None] = 16000
tokens_per_audio: int = 1500
class QualityvProcessingKwargs(ProcessingKwargs):
images_kwargs: QualityvImageKwargs
videos_kwargs: QualityvVideoKwargs
audio_kwargs: QualityvAudioKwargs
text_kwargs: TextKwargs
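# Example (sketch): per-modality options travel as nested kwargs, e.g.
#   processor(text=prompt, images="cat.png",  # hypothetical path
#             images_kwargs={"tokens_per_image": 197},
#             videos_kwargs={"num_frames": 4, "tokens_per_frame": 197})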
class QualityvProcessor(ProcessorMixin):
attributes = ["image_processor",
"audio_processor",
"tokenizer"]
image_processor_class = "AutoImageProcessor"
audio_processor_class = "AutoFeatureExtractor"
tokenizer_class = "AutoTokenizer"
chat_template = """{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set audio_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endif %}<|im_start|>{{ message['role'] }}
{% if message['content'] is string %}{{ message['content'] }}<|im_end|>
{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'audio' or 'audio' in content %}{% set audio_count.value = audio_count.value + 1 %}{% if add_vision_id %}Audio {{ audio_count.value }}: {% endif %}<|vision_start|><|audio_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>
{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
{% endif %}"""
def __init__(self, tokenizer=None,
image_processor=None,
audio_processor=None,
chat_template=None,
image_token="<|image_pad|>",
video_token="<|video_pad|>",
audio_token="<|audio_pad|>",
label_start_text="<|im_start|>assistant\n",
label_end_text="<|im_end|>\n",
**kwargs):
        # Prefer special tokens defined on the tokenizer, if any.
        self.image_token = getattr(tokenizer, "image_token", image_token)
        self.video_token = getattr(tokenizer, "video_token", video_token)
        self.audio_token = getattr(tokenizer, "audio_token", audio_token)
self.label_start_text = label_start_text
self.label_end_text = label_end_text
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None) is not None
            else tokenizer.convert_tokens_to_ids(self.image_token)
        )
        self.video_token_id = (
            tokenizer.video_token_id
            if getattr(tokenizer, "video_token_id", None) is not None
            else tokenizer.convert_tokens_to_ids(self.video_token)
        )
        self.audio_token_id = (
            tokenizer.audio_token_id
            if getattr(tokenizer, "audio_token_id", None) is not None
            else tokenizer.convert_tokens_to_ids(self.audio_token)
        )
if chat_template is None:
chat_template = self.chat_template
super().__init__(image_processor, audio_processor, tokenizer,
chat_template=chat_template)
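    # Construction note (sketch): instances are normally built via
    # QualityvProcessor.from_pretrained(<repo>), which resolves the three
    # attributes through AutoImageProcessor, AutoFeatureExtractor, and AutoTokenizer.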
def __call__(self,
text: Union[str, List[str], None] = None,
messages: Union[List[Dict], None] = None,
images: Union[ImageInput, None] = None,
videos: Union[VideoInput, None] = None,
audio: Union[AudioInput, None] = None,
do_train: bool = False,
add_generation_prompt: bool = False,
**kwargs: Unpack[QualityvProcessingKwargs]
):
        '''
        Input:
            messages: list of dicts, for example:
                [
                    {"role": "user",
                     "content": [
                         {"type": "text", "text": "Hello, how are you?"},
                         {"type": "image", "image": xxx},
                         {"type": "video", "video": xxx},
                     ]
                    },
                    ...
                ]
        Output (dict):
            input_ids
            pixel_values
            pixel_values_videos
            audio_values
            labels (None unless do_train=True)
        '''
input_ids = []
pixel_values = []
pixel_values_videos = []
audio_values = []
labels = None
if not text and not messages:
raise ValueError("At least one of text or messages must be provided.")
if messages:
text = self.apply_chat_template(messages, add_generation_prompt=add_generation_prompt,
tokenize=False)
if isinstance(text, list):
text = text[0]
image_list = self.fill_modal_list(self.image_token, "image", messages, images, text)
image_list = self.process_str_in_modal_list(image_list, "image", **kwargs.get("images_kwargs", {}))
        # Expand each image placeholder into tokens_per_image copies of the image token.
if image_list and self.image_token in text:
tokens_per_image = kwargs.get("images_kwargs", {}).get("tokens_per_image", 197)
text = text.replace(self.image_token, tokens_per_image * self.image_token)
pixel_values = self.image_processor(images=image_list, return_tensors="pt")["pixel_values"]
video_list = self.fill_modal_list(self.video_token, "video", messages, videos, text)
video_list = self.process_str_in_modal_list(video_list, "video", **kwargs.get("videos_kwargs", {}))
        # Expand each video placeholder into num_frames * tokens_per_frame video tokens.
        if video_list and self.video_token in text:
            tokens_per_frame = kwargs.get("videos_kwargs", {}).get("tokens_per_frame", 197)
            video_frame_list = []
            for video_item in video_list:
                # load_video_str returns (frames, metadata); raw frame arrays are used as-is.
                video = video_item[0] if isinstance(video_item, tuple) else video_item
                num_frames = video.shape[0]
                replace_text = num_frames * tokens_per_frame * self.video_token
                text = text.replace(self.video_token, replace_text, 1)
                for frame in video:
                    video_frame_list.append(frame)
            pixel_values_videos = self.image_processor(images=video_frame_list, return_tensors="pt")["pixel_values"]
audio_list = self.fill_modal_list(self.audio_token, "audio", messages, audio, text)
audio_list = self.process_str_in_modal_list(audio_list, "audio", **kwargs.get("audio_kwargs", {}))
        # Expand each audio placeholder into tokens_per_audio audio tokens.
        if audio_list and self.audio_token in text:
            audio_kwargs = kwargs.get("audio_kwargs", {})
            sampling_rate = audio_kwargs.get("sampling_rate", 16000)
            tokens_per_audio = audio_kwargs.get("tokens_per_audio", 1500)
            for _ in audio_list:  # one expansion per audio clip; avoid shadowing the `audio` argument
                text = text.replace(self.audio_token, tokens_per_audio * self.audio_token, 1)
            audio_values = self.audio_processor(audio_list, return_tensors="pt", sampling_rate=sampling_rate)["input_features"]
input_ids = self.tokenizer(text).input_ids
if do_train:
labels = self.get_labels(input_ids)
labels = torch.tensor(labels, dtype=torch.long)
input_ids = torch.tensor(input_ids, dtype=torch.long)
return {
"input_ids": input_ids,
"pixel_values": pixel_values if len(pixel_values) > 0 else None,
"pixel_values_videos": pixel_values_videos if len(pixel_values_videos) > 0 else None,
"audio_values": audio_values if len(audio_values) > 0 else None,
"labels": labels
}
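    # __call__ usage (sketch; repo name and media path are hypothetical):
    #   processor = QualityvProcessor.from_pretrained("some/qualityv-repo")
    #   batch = processor(
    #       messages=[{"role": "user", "content": [
    #           {"type": "text", "text": "Describe this image."},
    #           {"type": "image", "image": "photo.jpg"},
    #       ]}],
    #       add_generation_prompt=True,
    #   )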
    def fill_modal_list(self, modal_token: str, modal_type: str,
                        messages: List[Dict],
                        modal_values: Union[AudioInput, VideoInput, ImageInput, None],
                        text: str) -> List[Union[AudioInput, VideoInput, ImageInput]]:
        # Explicit `modal_values` take precedence; otherwise values are collected
        # from the user turns of `messages`. Nothing is collected unless the
        # placeholder token actually appears in the rendered text.
        modal_list = []
        if modal_token in text:
            if not modal_values and messages:
                for msg in messages:
                    if msg.get("role") == "user":
                        for content in msg.get("content", []):
                            if content.get("type") == modal_type:
                                modal_list.append(content.get(modal_type))
            elif modal_values:
                if isinstance(modal_values, str):
                    modal_list = [modal_values]
                else:
                    modal_list = modal_values
        return modal_list
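    # fill_modal_list example (sketch): with text containing "<|image_pad|>" and
    # images=None, the image values are pulled from the "image" entries of the
    # user turns in `messages`; an explicitly passed `images` argument wins.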
    def process_str_in_modal_list(self, modal_list: list, modal_type: str, **modal_kwargs):
new_modal_list = []
if modal_list:
for modal_value in modal_list:
if isinstance(modal_value, str):
new_modal_value = self.load_modal_str(modal_value, modal_type, **modal_kwargs)
new_modal_list.append(new_modal_value)
else:
new_modal_list.append(modal_value)
return new_modal_list
    def load_modal_str(self, modal_path_or_url: str, modal_type: str, **modal_kwargs):
        if modal_type == "image":
            load_func = load_image_str
        elif modal_type == "video":
            load_func = load_video_str
        elif modal_type == "audio":
            load_func = load_audio_str
        else:
            raise ValueError(f"Invalid modal type: {modal_type}")
        # Forward only the kwargs the loader accepts; processor-level options such
        # as tokens_per_image / tokens_per_frame / tokens_per_audio would otherwise
        # raise a TypeError inside the loader.
        allowed = set(inspect.signature(load_func).parameters)
        load_kwargs = {k: v for k, v in modal_kwargs.items() if k in allowed}
        return load_func(modal_path_or_url, **load_kwargs)
def get_labels(self, input_ids: List[int]) -> List[int]:
label_start_token_ids = self.tokenizer(self.label_start_text, add_special_tokens=False)["input_ids"]
label_end_token_ids = self.tokenizer(self.label_end_text, add_special_tokens=False)["input_ids"]
labels = [-100] * len(input_ids)
i = 0
while i < len(input_ids):
# Look for the assistant's response start marker.
if input_ids[i : i + len(label_start_token_ids)] == label_start_token_ids:
# The actual response begins after the start marker.
start_response = i + len(label_start_token_ids)
# Now, search for the end marker.
j = start_response
found_end = False
while j < len(input_ids):
                    if input_ids[j : j + len(label_end_token_ids)] == label_end_token_ids:
                        end_response = j + len(label_end_token_ids)  # end of the response, including the end marker
found_end = True
break
j += 1
if found_end:
# Copy the tokens corresponding to the assistant's response into labels.
labels[start_response:end_response] = input_ids[start_response:end_response]
# Advance i beyond the end marker.
i = end_response
continue # Continue scanning for the next assistant response.
else:
# If no end marker is found, break out of the loop.
break
else:
i += 1
        # Padding tokens never contribute to the loss.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is not None:
            labels = [-100 if token_id == pad_token_id else token_id for token_id in labels]
return labels
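    # get_labels example (sketch): for ids covering
    #   "...<|im_start|>assistant\nHi!<|im_end|>\n..."
    # every position outside "Hi!<|im_end|>\n" is set to -100, so the loss is
    # computed only on the assistant response (end marker included).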
def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)