| | import ast |
| | import copy |
| | import datetime |
| | import gc |
| | import io |
| | import json |
| | import math |
| | import mimetypes |
| | import os |
| | import random |
| | import re |
| | import sys |
| | import tarfile |
| | import tempfile |
| | import zipfile |
| | from collections import defaultdict, deque |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Tuple, Union |
| |
|
| | import av |
| | import cv2 |
| | import numpy as np |
| | import PIL |
| | import pkg_resources |
| | import scipy.signal as scsig |
| | import torch |
| | from decord import VideoReader, cpu |
| | from PIL import Image, ImageDraw |
| | from smart_open import open |
| | from torchvision.transforms.functional import to_tensor |
| |
|
| | from hcxvlm.dataset.base_dataset import image_decoder |
| | from hcxvlm.dataset.hcx_vision_prompter import HCXVisionPrompter |
| |
|
CHOICES = list(map(chr, range(97, 123)))  # lowercase ASCII letters "a".."z"
IGNORE_INDEX = -100  # label value excluded from the loss
DEFAULT_SAMPLE_RATE = 16000  # Hz
MIN_DISCRETE_AUDIO_CHUNK_SAMPLES = 1600  # 0.1 s at 16 kHz
DEFAULT_VOLUME_LEVEL = 10 ** (-26 / 20)  # -26 dBFS RMS target used by hpf_normalize
| |
|
| | hcx_vision_prompter = HCXVisionPrompter() |
| |
|
| |
|
| | def hpf_normalize( |
| | wav: np.ndarray, |
| | sr: int = DEFAULT_SAMPLE_RATE, |
| | volume_level: float = DEFAULT_VOLUME_LEVEL, |
| | ) -> np.ndarray: |
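    """Apply a 2nd-order 70 Hz Butterworth high-pass filter, then scale the signal
    toward an RMS of `volume_level`, capping the gain so the peak never exceeds 1.0.
    """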
    assert (wav**2).mean() > 0, "wav file contains only silence (all-zero samples)"
| |
|
| | filter_ = scsig.butter(2, 70, "highpass", fs=sr, output="sos") |
| | wav = scsig.sosfilt(filter_, wav) |
| | wav = wav.astype(np.float32) |
| |
|
| | gain = min(volume_level / (wav**2).mean() ** 0.5, 1 / np.max(np.abs(wav))) |
| | wav *= gain |
| | return wav |
| |
|
| |
|
| | def convert_bboxes(img, img_meta): |
| | for k, v in img_meta.items(): |
| | if k == "region": |
| | bbox_key = "bbox" if "bbox" in img_meta[k] else "boundingBox" |
| | img_meta[k] = reform_bbox( |
| | img_meta[k][bbox_key], img.size, format=img_meta[k]["format"] |
| | ) |
| | return img_meta |
| |
|
| |
|
| | def reform_bbox(bbox, image_size, format="REL_XYXY"): |
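    """Convert a relative bbox to absolute pixel coordinates as a 4-point polygon.

    Illustrative example (values chosen for exact arithmetic):
        >>> reform_bbox([0.25, 0.5, 0.75, 1.0], (100, 200))
        [[25.0, 100.0], [75.0, 100.0], [75.0, 200.0], [25.0, 200.0]]
    """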
| | w, h = image_size |
| | if format == "REL_XYXY": |
| | x1, y1, x2, y2 = bbox[0] * w, bbox[1] * h, bbox[2] * w, bbox[3] * h |
| | elif format == "REL_XYWH": |
| | x1, y1 = bbox[0] * w, bbox[1] * h |
| | x2, y2 = x1 + bbox[2] * w, y1 + bbox[3] * h |
| | else: |
        raise NotImplementedError(f"unsupported bbox format: {format}")
| | new_bbox = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]] |
| | return new_bbox |
| |
|
| |
|
| | def generate_random_color(use_alpha=True, seed=None): |
| | if seed is None: |
| | seed = np.random.default_rng() |
| |
|
| | if use_alpha: |
| | color_list = [ |
| | ("빨강", (255, 127, 127, 100)), |
| | ("노랑", (255, 255, 127, 100)), |
| | ("초록", (127, 255, 125, 100)), |
| | ("하늘", (127, 255, 255, 100)), |
| | ("파랑", (127, 127, 255, 100)), |
| | ("보라", (255, 127, 255, 100)), |
| | ] |
| | else: |
| | color_list = [ |
| | ("빨강", (255, 0, 0)), |
| | ("노랑", (255, 255, 0)), |
| | ("초록", (0, 128, 0)), |
| | ("하늘", (135, 206, 235)), |
| | ("파랑", (0, 0, 255)), |
| | ("보라", (128, 0, 128)), |
| | ] |
| | return color_list[seed.integers(0, len(color_list))] |
| |
|
| |
|
| | EN_COLOR = { |
| | "빨강": "red", |
| | "노랑": "yellow", |
| | "초록": "green", |
| | "하늘": "sky blue", |
| | "파랑": "blue", |
| | "보라": "purple", |
| | } |
| |
|
| |
|
| | def overlay_rectangle(image, words, lang, seed=None): |
| | color_str, color = generate_random_color(seed=seed) |
| | draw = ImageDraw.Draw(image, "RGBA") |
| | for word in words: |
| | shape_rect = word["bbox"] |
| | shape_rect = [(round(x[0]), round(x[1])) for x in shape_rect] |
| | draw.polygon(shape_rect, color) |
| | del draw |
| | if lang == "en": |
| | color_str = EN_COLOR[color_str] |
| | return image, color_str |
| |
|
| |
|
def convert_tags_for_video(img, json):
    """Video samples carry a single <video_00> tag instead of <image_xx> tags.
    Replace the <video_00> tag with one <image_xx> tag per image in `img`.
    """
| | image_tag = "".join([f"<image_{idx:02d}>" for idx in range(len(img))]) |
| | for json_key in json: |
| | if "qa_pairs" in json_key: |
| | new_qa_pairs = [] |
| | for qa_pair in json[json_key]: |
| | question = qa_pair[0] |
| | question = question.replace("<video_00>", image_tag) |
| | new_qa_pairs.append([question, qa_pair[1]]) |
| | json[json_key] = new_qa_pairs |
| |
|
| | return img, json |
| |
|
| |
|
| | def sampling_multiturn_single_img( |
| | seq, |
| | count, |
| | multiturn_preserve_order=True, |
| | multiturn_continuous=False, |
| | is_train: bool = True, |
| | seed=None, |
| | ): |
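    """Sample up to `count` turns from `seq`.

    `multiturn_continuous` takes a random contiguous window,
    `multiturn_preserve_order` takes a random subset in original order,
    otherwise a random subset in random order.
    """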
| | if seed is None: |
| | seed = np.random.default_rng() |
| | n_sample = min(count, len(seq)) |
| |
|
| | if multiturn_continuous: |
| | if len(seq) <= n_sample: |
| | start_index = 0 |
| | else: |
| | start_index = seed.integers(0, len(seq) - n_sample) |
| | indices = range(start_index, start_index + n_sample) |
| | elif multiturn_preserve_order: |
| | indices = sorted(seed.choice(range(len(seq)), size=n_sample, replace=False)) |
| | else: |
| | indices = seed.choice(range(len(seq)), size=n_sample, replace=False) |
| |
|
| | return [seq[i] for i in indices] |
| |
|
| |
|
| | def draw_bbox(image, bbox, lang="en", line_width=5, seed=None): |
| | if seed is None: |
| | seed = np.random.default_rng() |
| | color_str, color = generate_random_color(use_alpha=False, seed=seed) |
| | draw = ImageDraw.Draw(image, "RGB") |
| | rect_bbox = (bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1]) |
| | draw.rectangle(rect_bbox, outline=color, width=line_width) |
| | del draw |
| | if lang == "en": |
| | color_str = EN_COLOR[color_str] |
| | return image, color_str |
| |
|
| |
|
def bbox_process(bbox, detection_precision=2):
    # Format a flat coordinate list as "[x1, y1, x2, y2, ...]" with fixed precision.
    coords = [format(point, f".{detection_precision}f") for point in bbox]
    return "[" + ", ".join(coords) + "]"
| |
|
| |
|
| | def load_txt(file_path): |
| | lines_list = [] |
| | with open(file_path, "r") as file: |
| | for line in file: |
| | lines_list.append(line.replace("\\n", "\n").strip()) |
| | return lines_list |
| |
|
| |
|
| | def convert_format_for_multi_image( |
| | img, json, convert_key_list=["words", "text", "objects", "entities"] |
| | ): |
| | """single image dataset 과 multi image dataset 에서 읽어온 img, json format 이 다름. |
| | 따라서 single image dataset 에서 읽어온 img, json 을 multi image dataset 의 format (dict) 으로 convert |
| | """ |
| | is_multi_image_dataset = isinstance(img, dict) |
| | if not is_multi_image_dataset: |
| | img = {"00": img} |
| |
|
| | for convert_key in convert_key_list: |
| | if convert_key in json: |
| | json[convert_key] = {"00": json[convert_key]} |
| |
|
| | for json_key in json: |
| | if "region" in json_key: |
| | json[json_key] = {"00": json[json_key]} |
| | else: |
| | for convert_key in convert_key_list: |
| | if convert_key in json: |
| | if isinstance(json[convert_key], list): |
| | json[convert_key] = {"00": json[convert_key]} |
| |
|
| | for json_key in json: |
| | if "region" in json_key: |
| | if isinstance(json[json_key], list): |
| | json[json_key] = {"00": json[json_key]} |
| |
|
| | return is_multi_image_dataset, img, json |
| |
|
| |
|
| | class ConditionalError(Exception): |
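    """Raised when a sample fails a validity check during preprocessing."""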
| | def __init__(self, message="Our assertion error"): |
| | super().__init__(message) |
| |
|
| |
|
| | def get_wds_default_config(default_config, existing_default_config=None): |
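    """Return `default_config` with missing keys filled from the built-in WDS defaults
    (or from `existing_default_config` when given)."""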
| | if existing_default_config is None: |
| | default_config_check_dict = { |
| | "subtask": "", |
| | "reasoning": False, |
| | "use_task_prompt": True, |
| | "get_random": True, |
| | "add_instruct_prompts": [], |
| | "multiturn_n_samples": 0, |
| | "multiturn_preserve_order": True, |
| | "multiturn_continuous": False, |
| | "insert_ocr": 200, |
| | "ocr_filter_strategy": "confidence", |
| | "ocr_use_ratio": 1.0, |
| | "entity_top_k": 0, |
| | "entity_keyword_threshold": 100, |
| | "entity_keyword_fashion_threshold": 100, |
| | "entity_use_ratio": 0.0, |
| | "llava_pretrain": False, |
| | "random_system_prob": 0.0, |
| | "random_system_path": "", |
| | "random_tool_prob": 0.005, |
| | } |
| | else: |
| | default_config_check_dict = existing_default_config |
| | if default_config is None: |
| | default_config = default_config_check_dict |
| | else: |
| | for key, value in default_config_check_dict.items(): |
| | if key not in default_config: |
| | default_config[key] = value |
| | return default_config |
| |
|
| |
|
| | def get_datalake_default_config(default_config): |
| | default_config_check_dict = { |
| | "multiturn_n_samples": 0, |
| | "multiturn_preserve_order": True, |
| | "multiturn_continuous": True, |
| | "insert_ocr": 0, |
| | "ocr_filter_strategy": "confidence", |
| | "entity_top_k": 0, |
| | "entity_keyword_threshold": 0, |
| | "entity_keyword_fashion_threshold": 0, |
| | "entity_use_ratio": 0.0, |
| | "ocr_use_ratio": 0.0, |
| | "llava_pretrain": False, |
| | "random_system_prob": 0.0, |
| | "random_system_path": "", |
| | "random_tool_prob": 0.005, |
| | } |
| | if default_config is None: |
| | default_config = default_config_check_dict |
| | else: |
| | for key, value in default_config_check_dict.items(): |
| | if key not in default_config: |
| | default_config[key] = value |
| | return default_config |
| |
|
| |
|
| | @dataclass |
| | class Processed_sample: |
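    """Container for one preprocessed sample: prompt string, token/label ids, and any
    attached images, videos, and audio (plus their durations)."""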
| | input_str: str = None |
| | input_ids: list = None |
| | label_ids: list = None |
| | imgs: list = None |
| | discrete_imgs: list = None |
| | videos: list = None |
| | videos_duration: List[dict] = None |
| | video_audios: list = None |
| | audios: list = None |
| | audios_duration: List[dict] = None |
| | discrete_audios: list = None |
| | sample_mm_counter: dict = None |
| |
|
| |
|
| | from hcxvlm.dataset.bbox_processor import ( |
| | extract_bboxes, |
| | insert_bboxes_to_json, |
| | is_bbox_padded, |
| | ) |
| |
|
| |
|
| | class Preprocessor: |
| | prompt_head = "" |
| | va_prefix = "\n<|im_start|>" |
| | new_line = "\n" |
| | turn_prefix = "<|im_start|>" |
| | turn_suffix = "<|im_end|>" |
| | mime_start = "<|mime_start|>" |
| | mime_end = "<|mime_end|>" |
| | aux_img_start = "<|image_aux_start|>" |
| | aux_img_end = "<|image_aux_end|>" |
| | aux_video_start = "<|video_aux_start|>" |
| | aux_video_end = "<|video_aux_end|>" |
| | aux_audio_start = "<|audio_aux_start|>" |
| | aux_audio_end = "<|audio_aux_end|>" |
| | image_start = "<|image_start|>" |
| | image_end = "<|image_end|>" |
| | image_pad = "<|IMAGE_PAD|>" |
| | video_start = "<|video_start|>" |
| | video_end = "<|video_end|>" |
| | video_pad = "<|VIDEO_PAD|>" |
| | audio_start = "<|audio_start|>" |
| | audio_end = "<|audio_end|>" |
| | audio_pad = "<|AUDIO_PAD|>" |
| | discrete_image_start = "<|discrete_image_start|>" |
| | discrete_image_end = "<|discrete_image_end|>" |
| | discrete_image_pad = "<|DISCRETE_IMAGE_PAD|>" |
| | video_audio_pad = "<|VIDEO_AUDIO_PAD|>" |
| | discrete_audio_start = "<|discrete_audio_start|>" |
| | discrete_audio_end = "<|discrete_audio_end|>" |
| | discrete_audio_pad = "<|DISCRETE_AUDIO_PAD|>" |
| |
|
| | discrete_image_eol = "<|vision_eol|>" |
| | discrete_image_eof = "<|vision_eof|>" |
| | discrete_image_ratios = { |
| | (1, 1): "<|vision_ratio_1:1|>", |
| | (1, 2): "<|vision_ratio_1:2|>", |
| | (2, 1): "<|vision_ratio_2:1|>", |
| | (3, 4): "<|vision_ratio_3:4|>", |
| | (4, 3): "<|vision_ratio_4:3|>", |
| | (3, 5): "<|vision_ratio_3:5|>", |
| | (5, 3): "<|vision_ratio_5:3|>", |
| | (4, 5): "<|vision_ratio_4:5|>", |
| | (5, 4): "<|vision_ratio_5:4|>", |
| | (6, 9): "<|vision_ratio_6:9|>", |
| | (9, 6): "<|vision_ratio_9:6|>", |
| | (9, 16): "<|vision_ratio_9:16|>", |
| | (16, 9): "<|vision_ratio_16:9|>", |
| | } |
| |
|
    # Korean prompt: "The video_duration below is the length of the video. Refer to it when answering."
    aux_vid_prompt = (
        "다음 중 video_duration은 비디오 길이 정보입니다. 참고하여 답변하세요. "
    )
    # Korean prompt: "The audio_duration below is the length of the audio. Refer to it when answering."
    aux_audio_prompt = (
        "다음 중 audio_duration은 오디오 길이 정보입니다. 참고하여 답변하세요. "
    )
| |
|
| | def __init__( |
| | self, |
| | tokenizer=None, |
| | prepare_input_fn=None, |
| | prepare_audio_input_fn=None, |
| | sample_min_length=0, |
| | decoder_max_length=None, |
| | mode="train", |
| | model=None, |
| | datalake_default_config=None, |
| | wds_default_config=None, |
| | video_config=None, |
| | train_video=False, |
| | train_audio=False, |
| | sequence_parallel_size=1, |
| | video_audio_compressor_type=None, |
| | ): |
| | self.sequence_parallel_size = sequence_parallel_size |
| | if sequence_parallel_size > 1: |
| | self.rng = np.random.default_rng(seed=42) |
| | else: |
| | self.rng = np.random.default_rng() |
| |
|
| | if model is not None: |
| | tokenizer = model.tokenizer |
| | decoder_max_length = 16000 |
| |
|
        if model is not None and prepare_input_fn is None:
            raise ValueError("prepare_input_fn (image processor) is required when `model` is given")
| |
|
| | self.prepare_input_fn = prepare_input_fn |
| | self.prepare_audio_input_fn = prepare_audio_input_fn |
| | try: |
| | from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import ( |
| | Qwen2_5_VLProcessor, |
| | ) |
| |
|
| | self.is_qwen_visual = isinstance(prepare_input_fn, Qwen2_5_VLProcessor) |
| | except Exception as e: |
| | self.is_qwen_visual = False |
| | try: |
| | if not self.is_qwen_visual: |
| | from hcxvlm.models.processing_vlm import HCXVisionV2Processor |
| |
|
| | self.is_qwen_visual = isinstance(prepare_input_fn, HCXVisionV2Processor) |
| | except Exception as e: |
| | self.is_qwen_visual = False |
        assert self.is_qwen_visual, "prepare_input_fn must be a Qwen2_5_VLProcessor or HCXVisionV2Processor instance (or the processor import failed)"
| |
|
| | self.video_max_num_frames = ( |
| | video_config["video_max_num_frames"] |
| | if video_config and "video_max_num_frames" in video_config |
| | else 120 |
| | ) |
| | self.video_max_pixels = ( |
| | video_config["video_max_pixels"] |
| | if video_config and "video_max_pixels" in video_config |
| | else 378 * 378 |
| | ) |
| |
|
| | self.tokenizer = tokenizer |
| | self.sample_min_length = sample_min_length |
| | self.decoder_max_length = decoder_max_length |
| | self.mode = mode |
| | self.default_config = get_datalake_default_config(datalake_default_config) |
| | self.wds_default_config = get_wds_default_config(wds_default_config) |
| | self.train_video = train_video |
| | self.train_audio = train_audio |
| | self.video_audio_compressor_type = video_audio_compressor_type |
| |
|
| | self.img_token = self.tokenizer.encode(Preprocessor.image_pad)[0] |
| | assert ( |
| | len(self.tokenizer.encode(Preprocessor.image_pad)) == 1 |
| | ), "img_token is not configured in tokenizer" |
| |
|
| | self.discrete_image_token = self.tokenizer.encode( |
| | Preprocessor.discrete_image_pad |
| | )[0] |
| | assert ( |
| | len(self.tokenizer.encode(Preprocessor.discrete_image_pad)) == 1 |
| | ), "discrete_image_token is not configured in tokenizer" |
| |
|
| | self.discrete_image_eol_token = self.tokenizer.encode( |
| | Preprocessor.discrete_image_eol |
| | )[0] |
| | assert ( |
| | len(self.tokenizer.encode(Preprocessor.discrete_image_eol)) == 1 |
| | ), "discrete_image_eol_token is not configured in tokenizer" |
| |
|
| | self.discrete_image_eof_token = self.tokenizer.encode( |
| | Preprocessor.discrete_image_eof |
| | )[0] |
| | assert ( |
| | len(self.tokenizer.encode(Preprocessor.discrete_image_eof)) == 1 |
| | ), "discrete_image_eof_token is not configured in tokenizer" |
| |
|
| | self.discrete_image_ratio_tokens = dict() |
| | for ratio, token_str in Preprocessor.discrete_image_ratios.items(): |
| | token_id = self.tokenizer.encode(token_str)[0] |
| | assert ( |
| | len(self.tokenizer.encode(token_str)) == 1 |
| | ), f"discrete_image_ratio_token {token_str} is not configured in tokenizer" |
| | self.discrete_image_ratio_tokens[ratio] = token_id |
| |
|
| | self.video_token = self.tokenizer.encode(Preprocessor.video_pad)[0] |
| | assert ( |
| | len(self.tokenizer.encode(Preprocessor.video_pad)) == 1 |
| | ), "video_token is not configured in tokenizer" |
| |
|
| | self.video_audio_token = self.tokenizer.encode(Preprocessor.video_audio_pad)[0] |
| | assert ( |
| | len(self.tokenizer.encode(Preprocessor.video_audio_pad)) == 1 |
| | ), "video_audio_token is not configured in tokenizer" |
| |
|
| | def resize_min_edge(img: Image.Image) -> Image.Image: |
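            """Upscale the image, preserving aspect ratio, so the shorter edge is at least 28 px."""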
| | w, h = img.size |
| | min_size = 28 |
| | if min(w, h) >= min_size: |
| | return img |
| | if w < h: |
| | new_w = min_size |
| | new_h = int(h * (min_size / w)) |
| | else: |
| | new_h = min_size |
| | new_w = int(w * (min_size / h)) |
| | return img.resize((new_w, new_h), Image.BICUBIC) |
| |
|
| | self._resize_min_edge = resize_min_edge |
| |
|
| | self.audio_token = self.tokenizer.encode(Preprocessor.audio_pad)[0] |
| | assert ( |
| | len(self.tokenizer.encode(Preprocessor.audio_pad)) == 1 |
| | ), "audio_token is not configured in tokenizer" |
| |
|
| | self.discrete_audio_token = self.tokenizer.encode( |
| | Preprocessor.discrete_audio_pad |
| | )[0] |
        assert (
            len(self.tokenizer.encode(Preprocessor.discrete_audio_pad)) == 1
        ), "discrete_audio_token is not configured in tokenizer"
| |
|
| | from hcxvlm.dataset.json_processer import generate_prompt |
| |
|
| | self.generate_prompt = generate_prompt |
| |
|
        # Word lists used to synthesize plausible random filenames for mime prompts.
        self.mimes = list()
| | for mime_filename in [ |
| | "words_alpha.txt", |
| | "korean-366506-wordslistUnique.txt", |
| | ]: |
| | self.mimes += ( |
| | pkg_resources.resource_string( |
| | "hcxvlm", f"dataset/hcx_vision_prompter/prompts/{mime_filename}" |
| | ) |
| | .decode("utf-8") |
| | .split("\r\n") |
| | ) |
| |
|
| | self.common_tools = [] |
| | try: |
| | common_tools_bytes = pkg_resources.resource_string( |
| | "hcxvlm", |
| | "dataset/hcx_vision_prompter/prompts/common_tools.jsonl", |
| | ) |
| | for line in common_tools_bytes.decode("utf-8").splitlines(): |
| | line = line.strip() |
| | if not line: |
| | continue |
| | try: |
| | self.common_tools.append(json.loads(line)) |
| | except Exception: |
| | continue |
| | except Exception: |
| | self.common_tools = [] |
| |
|
| | self.random_system_prompt = "" |
| | if self.default_config["random_system_path"] != "": |
| | self.random_system_prompt = "" |
| | with open(self.default_config["random_system_path"], "r") as f: |
| | for line in f: |
| | self.random_system_prompt += line |
| |
|
| | if ( |
| | self.random_system_prompt != "" |
| | and self.wds_default_config["random_system_path"] != "" |
| | ): |
| | assert ( |
| | self.wds_default_config["random_system_path"] |
| | == self.default_config["random_system_path"] |
| | ), "random_system_path in both default_config and wds_default_config should be the same" |
| |
|
| | def _find_best_ratio_token(self, original_size): |
| | """Find the best ratio token based on original_size""" |
| | base_ratios = list(self.discrete_image_ratio_tokens.keys()) |
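        # Pair each configured ratio with its transpose; [1:] drops one of the two
        # duplicated (1, 1) entries produced for the square ratio.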
| | vision_aspect_ratios = [ |
| | r for ratio in base_ratios for r in [ratio, ratio[::-1]] |
| | ][1:] |
| |
|
| | if not isinstance(original_size, list) or len(original_size) != 2: |
| | return self.discrete_image_ratio_tokens[(1, 1)] |
| |
|
| | h, w = original_size |
| | if h == 0 or w == 0: |
| | return self.discrete_image_ratio_tokens[(1, 1)] |
| |
|
| | ratios = [i / j for i, j in vision_aspect_ratios] |
| |
|
| | best_size_idx = np.argmin([abs(w / h - r) for r in ratios]) |
| |
|
| | i, j = vision_aspect_ratios[best_size_idx] |
| | return self.discrete_image_ratio_tokens[(i, j)] |
| |
|
| | @classmethod |
| | def prompt_mime( |
| | cls, |
| | mimes: Optional[list[str]] = None, |
| | file_name: str = None, |
| | tag_idx: int = 1, |
| | fixed_mime: bool = False, |
| | is_video: bool = False, |
| | is_audio: bool = False, |
| | seed: np.random.Generator = None, |
    ) -> dict:
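        """Build the mime-metadata dict ({id, type, filename}) announced before a media block.

        When no real `file_name` is available, a plausible filename is synthesized from
        the `mimes` word list with a random or fixed extension.
        """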
| | assert mimes or file_name |
| |
|
| | if seed is None: |
| | seed = np.random.default_rng() |
| |
|
| | if file_name: |
| | name, ext = os.path.splitext(file_name) |
| | ext = ext.lstrip(".") |
| | elif fixed_mime: |
| | ext = "jpeg" |
| | name = mimes[tag_idx] |
| | elif not fixed_mime and seed is not None: |
| | ext = seed.choice(["png", "jpeg"]) |
| | name = mimes[seed.integers(0, len(mimes))] |
| | else: |
| | ext = "jpeg" |
| | name = mimes[tag_idx] |
| |
|
| | if is_video: |
| | ext_candidates = ["mp4", "mov", "avi", "webm"] |
| | if fixed_mime: |
| | ext = "mp4" |
| | elif ext not in ext_candidates: |
| | ext = seed.choice(ext_candidates) |
| |
|
| | filename = f"{name}.{ext}" |
| | mime_type = mimetypes.guess_type(filename)[0] |
| | mime_prompt = { |
| | "id": f"video_{str(tag_idx).zfill(2)}", |
| | "type": f"{mime_type}", |
| | "filename": f"{filename}", |
| | } |
| | return mime_prompt |
| |
|
| | if is_audio: |
| | ext_candidates = ["mp3", "wav", "aac", "flac", "pcm"] |
| | if fixed_mime: |
| | ext = "wav" |
| | elif ext not in ext_candidates: |
| | ext = seed.choice(ext_candidates) |
| |
|
| | filename = f"{name}.{ext}" |
| | mime_type = mimetypes.guess_type(filename)[0] |
| | mime_prompt = { |
| | "id": f"audio_{str(tag_idx).zfill(2)}", |
| | "type": f"{mime_type}", |
| | "filename": f"{filename}", |
| | } |
| | return mime_prompt |
| |
|
| | if file_name: |
| | filename = f"{name}.{ext}" |
| | mime_type = mimetypes.guess_type(filename)[0] |
| | mime_prompt = { |
| | "id": f"image_{str(tag_idx).zfill(2)}", |
| | "type": f"{mime_type}", |
| | "filename": f"{filename}", |
| | } |
| | else: |
| | mime_prompt = { |
| | "id": f"image_{str(tag_idx).zfill(2)}", |
| | "type": f"image/{ext}", |
| | "filename": f"{name}.{'jpg' if ext == 'jpeg' else 'png'}", |
| | } |
| | return mime_prompt |
| |
|
| | @classmethod |
| | def ocr_preprocess( |
| | cls, |
| | words: list[dict], |
| | n_insert_ocr_tokens: int = 2000, |
| | insert_ocr: int = 200, |
| | ocr_use_ratio: float = 0.5, |
| | tokenizer=None, |
| | seed=None, |
    ) -> Optional[str]:
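        """Join OCR words into a single whitespace-separated prompt string.

        OCR is injected with probability `ocr_use_ratio`; at most `insert_ocr` words are
        kept (confidence-filtered when more words than `insert_ocr` carry confidences),
        and the result is truncated to `n_insert_ocr_tokens` tokens if a tokenizer is
        given. Returns None when OCR is skipped.
        """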
| | if seed is None: |
| | seed = np.random.default_rng() |
| | if ocr_use_ratio < seed.random(): |
| | return None |
| | if insert_ocr == 0: |
| | return None |
| |
|
| | confidence_list = [] |
| | insert_ocr_prompt = [] |
| | for word in words: |
| | if "confidence" in word: |
| | confidence_list.append(word["confidence"]) |
| | has_ocr_confidence = len(confidence_list) >= insert_ocr |
| |
|
| | if len(words) <= insert_ocr or not has_ocr_confidence: |
| | insert_ocr_prompt += [ |
| | d["text"].strip() for d in words if d["text"].strip() |
| | ][:insert_ocr] |
| | else: |
| | confidence_threshold = 0.3 |
| | cnt = 0 |
| | for word in words: |
| | if word["text"] == "": |
| | continue |
| | if word["confidence"] >= confidence_threshold: |
| | insert_ocr_prompt.append(word["text"]) |
| | cnt += 1 |
| | if cnt >= insert_ocr: |
| | break |
| | ocr_inputs = " ".join(insert_ocr_prompt) |
| | if tokenizer: |
| | ocr_inputs = tokenizer.decode( |
| | tokenizer.encode(ocr_inputs)[:n_insert_ocr_tokens] |
| | ) |
| | return ocr_inputs |
| |
|
| | @classmethod |
| | def lens_preprocess( |
| | cls, |
| | lens: list[dict], |
| | entity_top_k: int = 100, |
| | entity_keyword_threshold: float = 0.0, |
| | entity_keyword_fashion_threshold: float = 0.0, |
| | entity_use_ratio: float = 0.0, |
| | seed=None, |
| | ): |
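        """Convert Naver Lens entity predictions into prompt keywords.

        Applied with probability `entity_use_ratio`; entities are confidence-filtered,
        sorted, and grouped into global ("lens_keywords") and localized
        ("lens_local_keywords") strings. Returns None when skipped.
        """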
| | if seed is None: |
| | seed = np.random.default_rng() |
| | if seed.uniform(0, 1) > entity_use_ratio: |
| | return None |
| |
|
| | entities = lens |
| | filter_idx = [] |
| | insert_entity_prompt = {} |
| | for idx, entity in enumerate(entities): |
| | if entity["type"] != "naver_lens_api": |
| | filter_idx.append(idx) |
| | continue |
| | if ( |
| | isinstance(entity_keyword_threshold, (int, float)) |
| | and entity["confidence"] < entity_keyword_threshold |
| | ): |
| | filter_idx.append(idx) |
| | continue |
| | if ( |
| | isinstance(entity_keyword_fashion_threshold, (int, float)) |
| | and ("fashion" in entity["info"]["classes"]) |
| | and entity["confidence"] < entity_keyword_fashion_threshold |
| | ): |
| | filter_idx.append(idx) |
| | continue |
| |
|
| | entityvalue = [ |
| | keyword for idx, keyword in enumerate(entities) if idx not in filter_idx |
| | ] |
| | entityvalue = sorted(entityvalue, key=lambda x: x["confidence"], reverse=True) |
| |
|
| | important_entity_list = [] |
| | local_entity_str_list = [] |
| | keywords_and_bbox_per_detector = {} |
| | for keyword_dict in entityvalue[:entity_top_k]: |
| | object_class = "/".join(keyword_dict["info"]["classes"]) |
| | if object_class not in keywords_and_bbox_per_detector.keys(): |
| | keywords_and_bbox_per_detector[object_class] = [] |
| | keywords_and_bbox_per_detector[object_class].append(keyword_dict) |
| |
|
| | for object_class in keywords_and_bbox_per_detector.keys(): |
| | entities_per_object = keywords_and_bbox_per_detector[object_class] |
| | normalized_bbox = bbox_process( |
| | [*entities_per_object[0]["bbox"][0], *entities_per_object[0]["bbox"][2]] |
| | ) |
| | entities = [entity["text"] for entity in entities_per_object] |
| | if "context" in object_class: |
| | important_entity_list += entities |
| |
|
| | else: |
| | local_entity_str_list += [ |
| | str(normalized_bbox) + " " + ", ".join(entities) |
| | ] |
| | if len(important_entity_list) > 0: |
| | insert_entity_prompt["lens_keywords"] = ", ".join(important_entity_list) |
| | if len(local_entity_str_list) > 0: |
| | insert_entity_prompt["lens_local_keywords"] = " ".join( |
| | local_entity_str_list |
| | ) |
| |
|
| | return insert_entity_prompt |
| |
|
| | @classmethod |
| | def prompt_toollist( |
| | cls, |
| | output, |
| | tokenizer=None, |
| | turn: Optional[dict] = None, |
        content: Optional[str] = None,
| | ): |
| | assert content or turn |
| | if turn is None: |
| | turn = { |
| | "role": "tool_list", |
| | "content": content, |
| | } |
| |
|
| | toollist_str = ( |
| | cls.turn_prefix.strip() |
| | + turn["role"] |
| | + "\n" |
| | + turn["content"] |
| | + cls.turn_suffix |
| | ) |
| |
|
| | if hasattr(output, "input_str"): |
| | output.input_str += toollist_str |
| |
|
| | if getattr(output, "input_ids", None) is not None: |
| | token_ids = tokenizer.encode(toollist_str, truncation=False) |
| | output.input_ids += token_ids |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
| | return output |
| |
|
| | @classmethod |
| | def prompt_system( |
| | cls, |
| | output, |
| | tokenizer=None, |
| | turn: Optional[dict] = None, |
| | content: Optional[str] = None, |
| | seed=None, |
| | tool_prompt=None, |
| | system_role_count=0, |
| | ): |
| | assert content or turn |
| | if seed is None: |
| | seed = np.random.default_rng() |
        if turn is None:
            # Assume the conventional "system" role when only raw content is given.
            turn = {"role": "system", "content": content}
            system_prompt = content
| | else: |
| | if "candidates" in turn: |
| | if len(turn["candidates"]) > 0: |
| | system_prompt = seed.choice(turn["candidates"]) |
| | if type(system_prompt) is dict: |
| | system_prompt = system_prompt["content"] |
| | else: |
| | system_prompt = "" |
            elif isinstance(turn["content"], str):
                system_prompt = turn["content"]
            elif len(turn["content"]) > 0:
                system_prompt = seed.choice(turn["content"])
            else:
                system_prompt = ""
| |
|
| | system_str = cls.turn_prefix + turn["role"] + "\n" |
| | system_str += system_prompt.strip() |
        if system_role_count == 0 and tool_prompt:
            if system_prompt.strip():
                system_str += "\n"
            system_str += tool_prompt
| | system_str += cls.turn_suffix |
| |
|
| | if hasattr(output, "input_str"): |
| | output.input_str += system_str |
| |
|
| | if getattr(output, "input_ids", None) is not None: |
| | token_ids = tokenizer.encode(system_str, truncation=False) |
| | output.input_ids += token_ids |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
| | return output |
| |
|
| | @classmethod |
| | def load_mm( |
| | cls, |
| | output, |
| | img_dir: str = "", |
| | turn: Optional[dict] = None, |
| | image_urls: Optional[list[str]] = None, |
| | image_metas: Optional[list[dict]] = None, |
| | video_urls: Optional[list[str]] = None, |
| | video_metas: Optional[list[dict]] = None, |
| | audio_urls: Optional[list[str]] = None, |
| | audio_metas: Optional[list[dict]] = None, |
| | prepare_input_fn=None, |
| | prepare_audio_input_fn=None, |
| | max_image_cnt=21, |
| | video_max_num_frames=None, |
| | video_max_pixels=None, |
| | use_audio: bool = False, |
| | audio_sample_rate: int = 16000, |
| | ): |
| | assert (image_urls or video_urls or audio_urls) or turn |
| | if turn is None: |
| | turn = {} |
| | if image_urls: |
| | turn.update({"image_urls": image_urls}) |
| | turn.update({"image_metas": image_metas}) |
| | if video_urls: |
| | turn.update({"video_urls": video_urls}) |
| | turn.update({"video_metas": video_metas}) |
| | if audio_urls: |
| | turn.update({"audio_urls": audio_urls}) |
| | turn.update({"audio_metas": audio_metas}) |
| |
|
| | if "video_urls" in turn: |
| | if len(turn["video_urls"]) and (prepare_input_fn is None): |
| | raise ConditionalError("video processing needs 'prepare_input_fn'") |
| |
|
| | if not isinstance(turn["content"], str): |
            raise ConditionalError("turn['content'] must be a string")
| |
|
| | turn["content"] = re.sub(r"<image_\d+>", "<|image|>", turn["content"]) |
| | pattern = re.compile( |
| | r"<\|video\|>|<\|image\|>|<\|t2i_model_generation_target_discrete_image\|>|<\|audio\|>|<\|discrete_audio\|>" |
| | ) |
| | tags = [match.group() for match in pattern.finditer(turn["content"])] |
| |
|
| | img_idx = 0 |
| | vid_idx = 0 |
| | aud_idx = 0 |
| |
|
| | if "image_urls" not in turn: |
| | turn["image_urls"] = [] |
| | if "video_urls" not in turn: |
| | turn["video_urls"] = [] |
| | if "audio_urls" not in turn: |
| | turn["audio_urls"] = [] |
| |
|
| | for tag in tags: |
| | if ( |
| | tag == "<|image|>" |
| | or tag == "<|t2i_model_generation_target_discrete_image|>" |
| | ): |
| | img_path = turn["image_urls"][img_idx] |
| |
|
| | if isinstance(img_path, str): |
| | if "#" in img_path: |
| | compression_path, img_path = img_path.split("#", 1) |
| | compression_path = os.path.join(img_dir, compression_path) |
| | assert compression_path[-4:] in [ |
| | ".zip", |
| | ".tar", |
| | ], f"unsupported compression format: {compression_path}" |
| |
|
| | with open(compression_path, "rb") as comp_file: |
| | if compression_path.endswith(".zip"): |
| | with zipfile.ZipFile(comp_file, "r") as zip_file: |
| | with zip_file.open(img_path) as img_file: |
| | img_binary = img_file.read() |
| | elif compression_path.endswith(".tar"): |
| | with tarfile.open( |
| | fileobj=comp_file, mode="r" |
| | ) as tar_file: |
| | img_file = tar_file.extractfile(img_path) |
| | img_binary = img_file.read() |
| | else: |
| | with open(os.path.join(img_dir, img_path), "rb") as f: |
| | img_binary = f.read() |
| | img = image_decoder(img_binary) |
| | else: |
| | if isinstance(img_path, (bytes, bytearray)): |
| | img = io.BytesIO(img_path) |
| | img = Image.open(img).convert("RGB") |
| | else: |
| | img = img_path |
| | if not isinstance(img, Image.Image): |
| | img = Image.fromarray(np.uint8(img)).convert("RGB") |
| |
|
| | if "image_metas" in turn and turn["image_metas"]: |
| | turn["image_metas"][img_idx] = convert_bboxes( |
| | img, turn["image_metas"][img_idx] |
| | ) |
| |
|
| | if tag == "<|image|>": |
| | output.imgs.append(img) |
| | output.discrete_imgs.append(img) |
| |
|
| | img_idx += 1 |
| | elif tag == "<|video|>": |
| | video_path = turn["video_urls"][vid_idx] |
| | if isinstance(video_path, str): |
| | if "#" in video_path: |
| | compression_path, video_path = video_path.split("#", 1) |
| | compression_path = os.path.join(img_dir, compression_path) |
| | assert compression_path[-4:] in [ |
| | ".zip", |
| | ".tar", |
| | ], f"unsupported compression format: {compression_path}" |
| |
|
| | with open(compression_path, "rb") as comp_file: |
| | if compression_path.endswith(".zip"): |
                                with zipfile.ZipFile(comp_file, "r") as zip_file:
                                    with zip_file.open(video_path) as video_file:
                                        video_binary = video_file.read()
| | elif compression_path.endswith(".tar"): |
| | with tarfile.open( |
| | fileobj=comp_file, mode="r" |
| | ) as tar_file: |
| | video_file = tar_file.extractfile(video_path) |
| | video_binary = video_file.read() |
| | else: |
| | with open(os.path.join(img_dir, video_path), "rb") as f: |
| | video_binary = f.read() |
| | video_binary = io.BytesIO(video_binary) |
| | else: |
| | video_binary = video_path |
| |
|
| | assert isinstance(video_binary, io.BytesIO), "video binary read error" |
| |
|
                try:
                    from hcxvlm.dataset.qwen_vision_process import process_vision_info
                except ImportError:
                    from qwen_vl_utils import process_vision_info
| |
|
| | if video_max_num_frames is None: |
| | video_max_num_frames = 120 |
| | if video_max_pixels is None: |
| | video_max_pixels = 378 * 378 |
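                # Decode video frames with the Qwen-VL vision loader; the returned fps
                # metadata is used below to compute the clip duration in seconds.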
| |
|
| | messages = [ |
| | [ |
| | { |
| | "role": "user", |
| | "content": [ |
| | { |
| | "type": "video", |
| | "video": video_binary, |
| | "max_frames": video_max_num_frames, |
| | "max_pixels": video_max_pixels, |
| | } |
| | ], |
| | } |
| | ], |
| | ] |
| | _, videos, video_kwargs = process_vision_info( |
| | messages, |
| | return_video_kwargs=True, |
| | use_audio=use_audio, |
| | audio_sample_rate=audio_sample_rate, |
| | ) |
| | output.videos.append(videos[0]) |
| | video_len = round(videos[0].shape[0] / video_kwargs["fps"][0], 2) |
| | output.videos_duration.append( |
| | { |
| | "video_duration": f"{video_len}s", |
| | } |
| | ) |
| |
|
| | if use_audio and "audio_chunks" in video_kwargs: |
| | audio_chunks = video_kwargs["audio_chunks"][0] |
| | if audio_chunks is not None: |
| | output.video_audios.append(audio_chunks) |
| | else: |
| | output.video_audios.append([]) |
| | elif use_audio: |
| | output.video_audios.append([]) |
| |
|
| | vid_idx += 1 |
| |
|
| | elif tag == "<|audio|>" or tag == "<|discrete_audio|>": |
| | audio_path = turn["audio_urls"][aud_idx] |
| | if isinstance(audio_path, str): |
| | if "#" in audio_path: |
| | compression_path, inner_path = audio_path.split("#", 1) |
| | compression_path = os.path.join(img_dir, compression_path) |
| | assert compression_path[-4:] in [ |
| | ".zip", |
| | ".tar", |
| | ], f"unsupported compression format: {compression_path}" |
| | with open(compression_path, "rb") as comp_file: |
| | if compression_path.endswith(".zip"): |
| | with zipfile.ZipFile(comp_file, "r") as zip_file: |
| | with zip_file.open(inner_path) as audio_file: |
| | audio_binary = audio_file.read() |
| | elif compression_path.endswith(".tar"): |
| | with tarfile.open( |
| | fileobj=comp_file, mode="r" |
| | ) as tar_file: |
| | audio_file = tar_file.extractfile(inner_path) |
| | audio_binary = audio_file.read() |
| | else: |
| | with open(os.path.join(img_dir, audio_path), "rb") as f: |
| | audio_binary = f.read() |
| | audio_stream = io.BytesIO(audio_binary) |
| | else: |
| | if isinstance(audio_path, (bytes, bytearray)): |
| | audio_stream = io.BytesIO(audio_path) |
| | else: |
| | audio_stream = audio_path |
| |
|
| | try: |
| | import librosa |
| |
|
| | y, sr = librosa.load( |
| | audio_stream, sr=DEFAULT_SAMPLE_RATE, mono=True |
| | ) |
| | assert ( |
| | DEFAULT_SAMPLE_RATE == sr |
| | ), f"librosa resampling failed: {DEFAULT_SAMPLE_RATE} != {sr}" |
| | except Exception as e: |
| | raise ConditionalError( |
| | f"audio decoding failed for {audio_path}: {e}" |
| | ) |
| |
|
| | audio_duration = len(y) / sr |
| | if audio_duration < 0.5: |
| | raise ConditionalError( |
| | f"Audio too short ({audio_duration:.2f}s). Minimum 0.5s required." |
| | ) |
| | if audio_duration > 600: |
| | raise ConditionalError( |
| | f"Audio duration ({audio_duration:.2f}s) exceeds maximum allowed duration (600s)" |
| | ) |
| |
|
| | if len(y) < MIN_DISCRETE_AUDIO_CHUNK_SAMPLES: |
| | raise ConditionalError( |
| | f"Audio too short ({len(y)} samples = {audio_duration:.4f}s < 0.1s). " |
| | f"Minimum {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES} samples required for CosyVoice encoder." |
| | ) |
| |
|
| | if not hasattr(output, "audios"): |
| | output.audios = [] |
| | if not hasattr(output, "discrete_audios"): |
| | output.discrete_audios = [] |
| |
|
| | normalized_y = hpf_normalize(y) |
| | normalized_y = torch.from_numpy(normalized_y).float() |
| |
|
| | output.discrete_audios.append(normalized_y) |
| | if tag == "<|audio|>": |
| |
|
| | output.audios.append(y) |
| | total_duration = len(y) / sr |
| | output.audios_duration.append( |
| | { |
| | "duration": f"{(total_duration):.2f}s", |
| | } |
| | ) |
| |
|
| | aud_idx += 1 |
            else:
                raise ConditionalError(f"unexpected media tag: {tag}")
| |
|
| | return output |
| |
|
| | @classmethod |
| | def prompt_user( |
| | cls, |
| | output, |
| | tokenizer=None, |
| | turn: Optional[dict] = None, |
| | content: Optional[str] = None, |
| | is_train=False, |
| | fixed_mime=False, |
| | insert_ocr=300, |
| | file_names: Optional[list[str]] = None, |
| | mimes: Optional[list[str]] = None, |
| | mm_tokens: Optional[list[str]] = None, |
| | words: Optional[list] = None, |
| | lens: Optional[list] = None, |
| | query_template: Optional[list[str]] = None, |
| | config: Optional[dict] = None, |
| | seed: np.random.Generator = None, |
| | ): |
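        """Append one user turn to `output`, expanding <|image|>/<|video|>/<|audio|>
        tags into mime, auxiliary, and pad-token blocks."""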
| | assert content or turn |
| | if turn is None: |
| | image_metas = [ |
| | {"words": words[i], "lens": lens[i]} for i in range(len(words)) |
| | ] |
| | turn = { |
| | "content": content, |
| | "image_metas": image_metas, |
| | } |
| | if seed is None: |
| | seed = np.random.default_rng() |
| |
|
| | turn["content"] = re.sub(r"<image_\d+>", "<|image|>", turn["content"]) |
| | turn["content"] = re.sub(r"<video_\d+>", "<|video|>", turn["content"]) |
| | turn["content"] = re.sub(r"<audio_\d+>", "<|audio|>", turn["content"]) |
| |
|
| | pattern = re.compile(r"(<\|video\|>|<\|image\|>|<\|audio\|>)") |
| |
|
| | all_tags_in_order = [ |
| | match.group() for match in pattern.finditer(turn["content"]) |
| | ] |
| | n_vids = sum(1 for tag in all_tags_in_order if tag == "<|video|>") |
| | n_audios = sum(1 for tag in all_tags_in_order if tag == "<|audio|>") |
| |
|
| | assert ( |
| | len(turn.get("image_urls", [])) |
| | + len(turn.get("video_urls", [])) |
| | + len(turn.get("audio_urls", [])) |
| | ) == len( |
| | all_tags_in_order |
        ), "Number of media URLs does not match the number of media tags."
| |
|
| | if mm_tokens is None: |
| | mm_tokens = [ |
| | cls.audio_pad if tag == "<|audio|>" else cls.image_pad |
| | for tag in all_tags_in_order |
| | ] |
| |
|
| | assert len(mm_tokens) == len(all_tags_in_order) |
| |
|
        if config and config.get("llava_pretrain", False):
| | mm_str = "".join([mm_tokens[i] for i in range(len(all_tags_in_order))]) |
| | if hasattr(output, "input_str"): |
| | output.input_str += mm_str |
| |
|
| | if getattr(output, "input_ids", None) is not None: |
| | token_ids = tokenizer.encode(mm_str, truncation=False) |
| | output.input_ids += token_ids |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
| | return output |
| |
|
| | if query_template: |
| | processed_content = seed.choice(query_template).format(turn["content"]) |
| |
|
| | tags_after_template = pattern.findall(processed_content) |
| | if len(all_tags_in_order) != len(tags_after_template): |
| | cleaned_template_text = pattern.sub("", processed_content) |
| | processed_content = "".join(all_tags_in_order) + cleaned_template_text |
| | turn["content"] = processed_content |
| |
|
| | content_parts = pattern.split(turn["content"].strip()) |
| |
|
| | if hasattr(output, "input_str"): |
| | output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" |
| | if getattr(output, "input_ids", None) is not None: |
| | role_encoded = tokenizer.encode( |
| | f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False |
| | ) |
| | output.input_ids += role_encoded |
| | if turn.get("trainable_role", False): |
| | output.label_ids += role_encoded |
| | else: |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] |
| |
|
| | tag_cursor = 0 |
| |
|
| | for part in content_parts: |
| | part = part.strip() |
| |
|
| | if not part: |
| | continue |
| |
|
| | if part not in ["<|image|>", "<|video|>", "<|audio|>"]: |
| | content_text = part |
| |
|
| | if hasattr(output, "input_str"): |
| | output.input_str += "\n" + content_text |
| | if getattr(output, "input_ids", None) is not None: |
| | content_encoded = tokenizer.encode( |
| | "\n" + content_text, truncation=False |
| | ) |
| | output.input_ids += content_encoded |
| | if turn.get("trainable_content", False): |
| | output.label_ids += content_encoded |
| | else: |
| | output.label_ids += [ |
| | IGNORE_INDEX for _ in range(len(content_encoded)) |
| | ] |
| | continue |
| |
|
| | if part == "<|image|>": |
| | mime = Preprocessor.prompt_mime( |
| | mimes=mimes if not file_names else None, |
| | fixed_mime=fixed_mime if not file_names else False, |
| | file_name=file_names[tag_cursor] if file_names else None, |
| | tag_idx=output.sample_mm_counter["image"], |
| | is_video=False, |
| | is_audio=False, |
| | seed=seed, |
| | ) |
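                # Emit the media block for this image: mime metadata, the discrete
                # (tokenized) image placeholder, then the continuous-feature placeholder.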
| | mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" |
| | discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" |
| | vector_str = f"{cls.image_start}{cls.image_pad}{cls.image_end}" |
| | mm_str = ( |
| | cls.new_line |
| | + mime_str |
| | + cls.new_line |
| | + discrete_image_str |
| | + cls.new_line |
| | + vector_str |
| | ) |
| |
|
| | if hasattr(output, "input_str"): |
| | output.input_str += mm_str |
| | if getattr(output, "input_ids", None) is not None: |
| | token_ids = tokenizer.encode(mm_str, truncation=False) |
| | output.input_ids += token_ids |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
| |
|
| | output.sample_mm_counter["image"] += 1 |
| | tag_cursor += 1 |
| |
|
| | elif part == "<|video|>": |
| | mime = Preprocessor.prompt_mime( |
| | mimes=mimes if not file_names else None, |
| | fixed_mime=fixed_mime if not file_names else False, |
| | file_name=file_names[tag_cursor] if file_names else None, |
| | tag_idx=output.sample_mm_counter["video"], |
| | is_video=True, |
| | is_audio=False, |
| | seed=seed, |
| | ) |
| | mm_str = "" |
| | aux_inputs = { |
| | "video_duration": output.videos_duration[ |
| | output.sample_mm_counter["video"] |
| | ]["video_duration"], |
| | } |
| | mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" |
| | aux_str = f"{cls.aux_video_start}{cls.aux_vid_prompt}{json.dumps(aux_inputs, ensure_ascii=False)}{cls.aux_video_end}" |
| | vector_str = f"{cls.video_start}{cls.video_pad}{cls.video_end}" |
| | mm_str += ( |
| | cls.new_line |
| | + mime_str |
| | + cls.new_line |
| | + aux_str |
| | + cls.new_line |
| | + vector_str |
| | ) |
| | if hasattr(output, "input_str"): |
| | output.input_str += mm_str |
| | if getattr(output, "input_ids", None) is not None: |
| | token_ids = tokenizer.encode(mm_str, truncation=False) |
| | output.input_ids += token_ids |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
| | output.sample_mm_counter["video"] += 1 |
| | tag_cursor += 1 |
| |
|
| | elif part == "<|audio|>": |
| | mime = Preprocessor.prompt_mime( |
| | mimes=mimes if not file_names else None, |
| | fixed_mime=fixed_mime if not file_names else False, |
| | file_name=file_names[tag_cursor] if file_names else None, |
| | tag_idx=output.sample_mm_counter["audio"], |
| | is_video=False, |
| | is_audio=True, |
| | seed=seed, |
| | ) |
| | mm_str = "" |
| | aux_inputs = { |
| | "audio_duration": output.audios_duration[ |
| | output.sample_mm_counter["audio"] |
| | ]["duration"], |
| | } |
| | mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" |
| | aux_str = f"{cls.aux_audio_start}{cls.aux_audio_prompt}{json.dumps(aux_inputs, ensure_ascii=False)}{cls.aux_audio_end}" |
| | discrete_audio_str = f"{cls.discrete_audio_start}{cls.discrete_audio_pad}{cls.discrete_audio_end}" |
| | vector_str = f"{cls.audio_start}{cls.audio_pad}{cls.audio_end}" |
| | mm_str += ( |
| | cls.new_line |
| | + mime_str |
| | + cls.new_line |
| | + aux_str |
| | + cls.new_line |
| | + discrete_audio_str |
| | + cls.new_line |
| | + vector_str |
| | ) |
| | if hasattr(output, "input_str"): |
| | output.input_str += mm_str |
| | if getattr(output, "input_ids", None) is not None: |
| | token_ids = tokenizer.encode(mm_str, truncation=False) |
| | output.input_ids += token_ids |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
| |
|
| | output.sample_mm_counter["audio"] += 1 |
| | tag_cursor += 1 |
| |
|
| | if hasattr(output, "input_str"): |
| | output.input_str += cls.turn_suffix |
| |
|
| | if getattr(output, "input_ids", None) is not None: |
| | token_ids = tokenizer.encode(cls.turn_suffix, truncation=False) |
| | output.input_ids += token_ids |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
| |
|
| | return output |
| |
|
| | @classmethod |
| | def prompt_assistant( |
| | cls, |
| | output, |
| | tokenizer=None, |
| | turn: Optional[dict] = None, |
| | role: Optional[str] = "assistant", |
| | content: Optional[str] = None, |
| | is_last_turn=False, |
| | is_eval=True, |
| | is_llava_pretrain=False, |
| | is_after_last_user_turn=False, |
| | ): |
| | assert content or turn |
| | if turn is None: |
| | turn = { |
| | "content": content, |
| | "role": role, |
| | } |
| |
|
| | if is_llava_pretrain: |
| | if hasattr(output, "input_str"): |
| | output.input_str += turn["content"] |
| | if getattr(output, "input_ids", None) is not None: |
| | content_encoded = tokenizer.encode(turn["content"], truncation=False) |
| | output.input_ids += content_encoded |
| | output.label_ids += content_encoded |
| | return output |
| |
|
| | reasoning_content = turn.get("reasoning_content", "") |
| | if ( |
| | not reasoning_content |
| | and isinstance(turn["content"], str) |
| | and "</think>" in turn["content"] |
| | ): |
| | parts = turn["content"].split("</think>", 1) |
| | reasoning_content = parts[0].split("<think>", 1)[-1].lstrip("\n") |
| | turn["content"] = parts[1].lstrip("\n") |
| |
|
        # `stripped_content` is also needed by the eval-mode last-turn branch below,
        # so compute it (and normalize reasoning_content) before branching.
        content_to_strip = turn.get("content") or ""
        stripped_content = content_to_strip.lstrip("\n")
        if reasoning_content is None:
            reasoning_content = ""

        if is_after_last_user_turn and (is_last_turn or reasoning_content):
            turn["content"] = (
                f"<think>\n{reasoning_content.strip()}\n</think>\n\n{stripped_content}"
            )
| |
|
| | if turn.get("tool_calls"): |
| | for tool_call in turn["tool_calls"]: |
| | func_name = tool_call.get("function", {}).get("name", "") |
| | args = tool_call.get("function", {}).get("arguments", {}) |
| |
|
| | if isinstance(args, str): |
| | try: |
| | args = json.loads(args) |
| | except Exception: |
| | pass |
                if not isinstance(args, dict):
                    print(
                        f"[error] tool_call.function.arguments is not a dict: type={type(args)}, value={str(args)}"
                    )
                    assert (
                        False
                    ), "tool_call.function.arguments must be a dict or a JSON string encoding a dict."
| |
|
| | tool_turn_content = f"\n<tool_call>{func_name}\n" |
| |
|
| | for key, value in args.items(): |
| | arg_value = ( |
| | json.dumps(value, ensure_ascii=False) |
| | if not isinstance(value, str) |
| | else value |
| | ) |
| | tool_turn_content += f"<arg_key>{key}</arg_key>\n<arg_value>{arg_value}</arg_value>\n" |
| | tool_turn_content += "</tool_call>" |
| |
|
| | if func_name == "t2i_model_generation": |
| | assert ( |
| | "<|t2i_model_generation_target_discrete_image|>" |
| | in turn["content"] |
| | ), "t2i_model_generation tool call must have target discrete image tag in content." |
| | turn["content"] = turn["content"].replace( |
| | "<|t2i_model_generation_target_discrete_image|>", |
| | tool_turn_content, |
| | ) |
| | else: |
| | turn["content"] += tool_turn_content |
| |
|
| | pattern = re.compile( |
| | r"(<\|image\|>|<\|discrete_image\|>|<\|audio\|>|<\|discrete_audio\|>)" |
| | ) |
| | all_tags_in_order = [ |
| | match.group() for match in pattern.finditer(turn["content"]) |
| | ] |
| |
|
| | assert ( |
| | len(turn.get("image_urls", [])) |
| | + len(turn.get("video_urls", [])) |
| | + len(turn.get("audio_urls", [])) |
| | ) == len( |
| | all_tags_in_order |
        ), "Number of media URLs does not match the number of media tags."
| |
|
| | if hasattr(output, "input_str"): |
| | output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" |
| | if is_eval and is_last_turn: |
| | if reasoning_content.strip() == "": |
| | output.input_str += f"<think>\n\n</think>\n\n" |
| | turn["content"] = stripped_content |
| | else: |
| | output.input_str += f"{turn['content']}{cls.turn_suffix}" |
| |
|
| | if getattr(output, "input_ids", None) is not None: |
| | role_encoded = tokenizer.encode( |
| | f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False |
| | ) |
| | output.input_ids += role_encoded |
| |
|
| | if is_eval and is_last_turn: |
| | if reasoning_content.strip() == "": |
| | output.input_ids += tokenizer.encode( |
| | f"<think>\n\n</think>\n\n", truncation=False |
| | ) |
| | turn["content"] = stripped_content |
| | else: |
| | if turn.get("trainable_role", True): |
| | output.label_ids += role_encoded |
| | else: |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] |
| |
|
| | turn_img_idx = 0 |
| | content_parts = pattern.split(turn["content"].strip()) |
| | for part in content_parts: |
| | part = part.strip() |
| |
|
| | if not part: |
| | continue |
| |
|
| | if part not in [ |
| | "<|image|>", |
| | "<|discrete_image|>", |
| | "<|audio|>", |
| | "<|discrete_audio|>", |
| | ]: |
| | content_text = part |
| |
|
| | if hasattr(output, "input_str"): |
| | output.input_str += "\n" + content_text |
| | if getattr(output, "input_ids", None) is not None: |
| | content_encoded = tokenizer.encode( |
| | "\n" + content_text, truncation=False |
| | ) |
| | output.input_ids += content_encoded |
| | if turn.get("trainable_content", True): |
| | output.label_ids += content_encoded |
| | else: |
| | output.label_ids += [ |
| | IGNORE_INDEX for _ in range(len(content_encoded)) |
| | ] |
| | continue |
| |
|
| | if part == "<|image|>": |
| | file_name = turn.get("image_urls", [])[turn_img_idx] |
| | if isinstance(file_name, str) and "#" in file_name: |
| | file_name = file_name.split("#")[-1] |
| | file_name = os.path.basename(file_name) |
| | mime = Preprocessor.prompt_mime( |
| | mimes=None, |
| | fixed_mime=False, |
| | file_name=file_name, |
| | tag_idx=output.sample_mm_counter["image"], |
| | is_video=False, |
| | is_audio=False, |
| | seed=None, |
| | ) |
| | mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" |
| | discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" |
| | vector_str = f"{cls.image_start}{cls.image_pad}{cls.image_end}" |
| | mm_str = ( |
| | cls.new_line |
| | + mime_str |
| | + cls.new_line |
| | + discrete_image_str |
| | + cls.new_line |
| | + vector_str |
| | ) |
| |
|
| | if hasattr(output, "input_str"): |
| | output.input_str += mm_str |
| | if getattr(output, "input_ids", None) is not None: |
| | token_ids = tokenizer.encode(mm_str, truncation=False) |
| | output.input_ids += token_ids |
| | output.label_ids += [ |
| | IGNORE_INDEX for _ in range(len(token_ids)) |
| | ] |
| | turn_img_idx += 1 |
| | output.sample_mm_counter["image"] += 1 |
| |
|
| | elif part == "<|discrete_image|>": |
| | discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" |
| | mm_str = cls.new_line + discrete_image_str |
| | if hasattr(output, "input_str"): |
| | output.input_str += mm_str |
| | if getattr(output, "input_ids", None) is not None: |
| | token_ids = tokenizer.encode(mm_str, truncation=False) |
| | output.input_ids += token_ids |
| | output.label_ids += token_ids |
| | turn_img_idx += 1 |
| |
|
| | elif part == "<|discrete_audio|>": |
| | discrete_audio_str = f"{cls.discrete_audio_start}{cls.discrete_audio_pad}{cls.discrete_audio_end}" |
| | mm_str = cls.new_line + discrete_audio_str |
| | if hasattr(output, "input_str"): |
| | output.input_str += mm_str |
| | if getattr(output, "input_ids", None) is not None: |
| | token_ids = tokenizer.encode(mm_str, truncation=False) |
| | output.input_ids += token_ids |
| | if turn.get("trainable_content", True): |
| | output.label_ids += token_ids |
| | else: |
| | output.label_ids += [ |
| | IGNORE_INDEX for _ in range(len(token_ids)) |
| | ] |
| |
|
            elif part == "<|audio|>":
                raise Exception(
                    "The <|audio|> tag is not supported in assistant turns; only <|discrete_audio|> is supported."
                )
| |
|
| | if hasattr(output, "input_str"): |
| | output.input_str += cls.turn_suffix |
| |
|
| | if getattr(output, "input_ids", None) is not None: |
| | token_ids = tokenizer.encode(cls.turn_suffix, truncation=False) |
| | output.input_ids += token_ids |
| | if turn.get("trainable_content", True): |
| | output.label_ids += token_ids |
| | else: |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
| |
|
| | return output |
| |
|
| | @classmethod |
| | def prompt_tool( |
| | cls, |
| | output, |
| | tokenizer=None, |
| | turn: Optional[dict] = None, |
| | role: Optional[str] = None, |
| | content: Optional[str] = None, |
| | eot: Optional[bool] = None, |
| | need_start_tag=True, |
| | need_end_tag=True, |
| | ): |
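        """Append one tool-response turn, wrapping the tool output in
        <tool_response>...</tool_response>."""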
| | assert (content and role) or turn |
| | if turn is None: |
| | turn = { |
| | "content": content, |
| | "role": role, |
| | "endofturn": eot, |
| | } |
| | assert ( |
| | "tool" == turn["role"] |
| | ), f'[warning] unexpected turn["role"]: {turn["role"]}' |
| | content_value = turn.get("content", "") |
| |
|
| | if isinstance(content_value, dict): |
| | if "response" in content_value: |
| | content_str = content_value["response"] |
| | else: |
| | content_str = json.dumps(content_value, ensure_ascii=False) |
| | elif isinstance(content_value, str): |
| | try: |
| | parsed = json.loads(content_value) |
| | if isinstance(parsed, dict): |
| | if "response" in parsed: |
| | content_str = parsed["response"] |
| | else: |
| | content_str = json.dumps(parsed, ensure_ascii=False) |
| | else: |
| | content_str = content_value |
| | except (json.JSONDecodeError, TypeError): |
| | content_str = content_value |
| | else: |
| | content_str = str(content_value) |
| |
|
| | turn["content"] = ( |
| | f"<tool_response>{turn.get('name', '')}\n{content_str}\n</tool_response>" |
| | ) |
| |
|
| | if hasattr(output, "input_str"): |
| | if need_start_tag: |
| | output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" |
| | output.input_str += f"{cls.new_line}{turn['content']}" |
| | if need_end_tag: |
| | output.input_str += cls.turn_suffix |
| |
|
| | if getattr(output, "input_ids", None) is not None: |
| | if need_start_tag: |
| | role_encoded = tokenizer.encode( |
| | f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False |
| | ) |
| | output.input_ids += role_encoded |
| |
|
| | if turn.get("trainable_role", True): |
| | output.label_ids += role_encoded |
| | else: |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] |
| |
|
| | content = f"{cls.new_line}{turn['content']}" |
| | content_encoded = tokenizer.encode(content, truncation=False) |
| | if need_end_tag: |
| | content_encoded += tokenizer.encode( |
| | f"{cls.turn_suffix}", truncation=False |
| | ) |
| | output.input_ids += content_encoded |
| | if turn.get("trainable_content", True): |
| | output.label_ids += content_encoded |
| | else: |
| | output.label_ids += [ |
| | IGNORE_INDEX for _ in range(len(content_encoded)) |
| | ] |
| | return output |
| |
|
| | @classmethod |
| | def prompt_etc( |
| | cls, |
| | output, |
| | tokenizer=None, |
| | turn: Optional[dict] = None, |
| | role: Optional[str] = None, |
| | content: Optional[str] = None, |
| | eot: Optional[bool] = None, |
| | ): |
| | assert (content and role) or turn |
| | if turn is None: |
| | turn = { |
| | "content": content, |
| | "role": role, |
| | "endofturn": eot, |
| | } |
| | print(f'[warning] unexpected turn["role"]: {turn["role"]}') |
| |
|
| | if hasattr(output, "input_str"): |
| | output.input_str += f"{cls.turn_prefix}{turn['role']}\n" |
| | output.input_str += f"{turn['content']}{cls.turn_suffix}" |
| | if turn.get("stop", False): |
| | output.input_str += cls.stop_token |
| | if turn.get("endofturn", False): |
| | output.input_str += cls.eot |
| |
|
| | if getattr(output, "input_ids", None) is not None: |
| | role_encoded = tokenizer.encode( |
| | f"{cls.turn_prefix}{turn['role']}\n", truncation=False |
| | ) |
| | output.input_ids += role_encoded |
| |
|
| | if turn.get("trainable_role", True): |
| | output.label_ids += role_encoded |
| | else: |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] |
| |
|
| | content = f"{turn['content']}{cls.turn_suffix}" |
| | if turn.get("stop", False): |
| | content += cls.stop_token |
| | if turn.get("endofturn", False): |
| | content += cls.eot |
| | content_encoded = tokenizer.encode(content, truncation=False) |
| | output.input_ids += content_encoded |
| | if turn.get("trainable_content", True): |
| | output.label_ids += content_encoded |
| | else: |
| | output.label_ids += [IGNORE_INDEX for _ in range(len(content_encoded))] |
| | return output |
| |
|
| | def __call__(self, sample): |
| | return self.preprocess_new(sample) |
| |
|
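| | # Collate per-sample dicts into a batch: tensors are optionally moved to `device`,
| | # *_grid_thw tensors are concatenated along dim 0, other tensors are stacked, and
| | # everything else is gathered into lists. Video fields are reset to None at the end.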
| | @classmethod |
| | def batchify( |
| | cls, |
| | items: List[Dict[str, Any]],
| | device: Optional[str] = None,
| | ): |
| | batch = dict() |
| | for item in items: |
| | for k, v in item.items(): |
| | if isinstance(v, torch.Tensor): |
| | if device is not None: |
| | v = v.to(device=device) |
| | elif k == "pixel_values": |
| | v = [_v.to(device=device) for _v in v] |
| |
|
| | if k not in batch: |
| | batch[k] = [ |
| | v, |
| | ] |
| | else: |
| | batch[k].append(v) |
| |
|
| | for k, v in batch.items(): |
| | if isinstance(v[0], torch.Tensor): |
| | if k in ["image_grid_thw", "video_grid_thw"]: |
| | batch[k] = torch.cat(v, dim=0) |
| | continue |
| | batch[k] = torch.stack(v, dim=0) |
| | batch["video_grid_thw"] = None |
| | batch["pixel_values_videos"] = None |
| | return batch |
| |
|
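| | # Convert a webdataset-style sample (image/video/audio payloads plus a metadata dict)
| | # into the datalake turn format: tool_list and system turns followed by alternating
| | # user/assistant turns built from captions, qa pairs, OCR words, or regions.
| | # Note: the `json` parameter is the sample's metadata dict, not the json module.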
| | def convert_wds_to_datalake( |
| | self, |
| | img: Union[PIL.Image.Image, Dict[str, PIL.Image.Image]] = {}, |
| | json: Dict[str, Any] = {}, |
| | benchmark: Optional[str] = None, |
| | video: Union[io.BytesIO, Dict[str, io.BytesIO]] = {}, |
| | audio: Union[io.BytesIO, Dict[str, io.BytesIO]] = {}, |
| | ): |
| |
|
| | if "lines" in json: |
| | del json["lines"] |
| | if "paragraphs" in json: |
| | del json["paragraphs"] |
| |
|
| | assert json["meta"]["type"] in [ |
| | "caption", |
| | "vqa", |
| | "textread", |
| | ], f"{json['meta']['path']}, {json['meta']['type']}: The dataset type should be one of them: caption, vqa, textread." |
| |
|
| | sample = {"vlm": {}} |
| | sample["vlm"] = get_wds_default_config( |
| | json["meta"], existing_default_config=self.wds_default_config |
| | ) |
| | sample["vlm"]["data_name"] = json["meta"].get("name", "unk") |
| |
|
| | sample["vlm"]["data_type"] = ( |
| | "wds" |
| | if (isinstance(img, PIL.Image.Image) and img) |
| | or (isinstance(img, dict) and len(img) > 0) |
| | else "sft1" |
| | ) |
| |
|
| | sample["vlm"]["sample_id"] = json.get("qa_id", None) |
| | sample["vlm"]["category"] = json.get("category", None) |
| | sample["vlm"]["data_info"] = json.get("data_info", dict()) |
| | sample["vlm"]["options"] = None |
| | if "choices_en" in sample["vlm"]["data_info"]: |
| | if sample["vlm"]["options"] is None and json["meta"]["lang"] == "en": |
| | sample["vlm"]["options"] = sample["vlm"]["data_info"]["choices_en"] |
| | sample["vlm"]["options_en"] = sample["vlm"]["data_info"]["choices_en"] |
| | if "choices_ko" in sample["vlm"]["data_info"]: |
| | if sample["vlm"]["options"] is None and json["meta"]["lang"] == "ko": |
| | sample["vlm"]["options"] = sample["vlm"]["data_info"]["choices_ko"] |
| | sample["vlm"]["options_ko"] = sample["vlm"]["data_info"]["choices_ko"] |
| | sample["vlm"]["image_index"] = json.get( |
| | "image_index", json.get("img_url", None) |
| | ) |
| |
|
| | if sample["vlm"].get("video", False): |
| | is_multi_image_dataset = False |
| | else: |
| | is_multi_image_dataset, img, json = convert_format_for_multi_image( |
| | img, json |
| | ) |
| |
|
| | if json["meta"]["type"] == "textread": |
| | key = "words" |
| | elif json["meta"].get("subtask", "") == "region": |
| | key = f"regions_{json['meta']['lang']}" |
| | elif json["meta"]["type"] == "vqa": |
| | key = f"qa_pairs_{json['meta']['lang']}" |
| | elif json["meta"]["type"] == "caption": |
| | key = f"captions_{json['meta']['lang']}" |
| | else: |
| | raise ConditionalError( |
| | f"wrong task type in wds config: {sample['vlm']['data_name']}" |
| | ) |
| |
|
| | turns = [ |
| | { |
| | "role": "tool_list", |
| | "content": "", |
| | "content_type": "text", |
| | "trainable_role": False, |
| | "trainable_content": False, |
| | "stop": False, |
| | "debuggingInfo": {}, |
| | "meta": {}, |
| | "candidates": [], |
| | "endofturn": False, |
| | }, |
| | { |
| | "role": "system", |
| | "content_type": "text", |
| | "candidates": [], |
| | "trainable_role": False, |
| | "trainable_content": False, |
| | "stop": False, |
| | "debuggingInfo": {}, |
| | "meta": {}, |
| | "content": "", |
| | "endofturn": False, |
| | }, |
| | ] |
| |
|
| | if json["meta"].get("llava_pretrain", False): |
| | sample["vlm"]["llava_pretrain"] = True |
| |
|
| | use_task_prompt = json["meta"].get( |
| | "use_task_prompt", self.wds_default_config["use_task_prompt"] |
| | ) |
| | get_random = json["meta"].get( |
| | "get_random", self.wds_default_config["get_random"] |
| | ) |
| | reasoning = json["meta"].get("reasoning", self.wds_default_config["reasoning"]) |
| |
|
| | try: |
| | if key not in json: |
| | key = key[:-3] |
| | assert key in json |
| | if len(json[key]) == 0: |
| | key = key[:-3] |
| | assert key in json |
| | except Exception:
| | raise ConditionalError(
| | f"{key} key is missing or empty in json; dataset name: {sample['vlm']['data_name']}"
| | )
| |
|
| | first_turn = True |
| | if "region" in key: |
| | json[key] = json[key]["00"] |
| | sample["vlm"]["multiturn_n_samples"] = 1 |
| | if (
| | (not is_multi_image_dataset and sample["vlm"]["multiturn_n_samples"] > 1)
| | or "region" in key
| | ):
| | json[key] = sampling_multiturn_single_img( |
| | json[key], |
| | sample["vlm"]["multiturn_n_samples"], |
| | sample["vlm"]["multiturn_preserve_order"], |
| | sample["vlm"]["multiturn_continuous"], |
| | ) |
| |
|
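| | # Video samples: each qa pair becomes a user/assistant turn pair, and the <|video|>
| | # placeholder plus the shared video source are attached only to the first user turn.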
| | if sample["vlm"].get("video", False): |
| | for qa in json[key]: |
| | vid_src = [] |
| | user = { |
| | "role": "user", |
| | "content_type": "text", |
| | "candidates": [], |
| | "trainable_role": False, |
| | "trainable_content": False, |
| | "stop": False, |
| | "debuggingInfo": {}, |
| | "meta": {}, |
| | "image_urls": [], |
| | "image_metas": [], |
| | "video_urls": [], |
| | "video_metas": [], |
| | "audio_urls": [], |
| | "audio_metas": [], |
| | "content": "", |
| | "endofturn": False, |
| | } |
| |
|
| | instruct_prompt, task_prompt = hcx_vision_prompter( |
| | task=json["meta"]["type"], |
| | subtask=json["meta"].get("subtask", None), |
| | lang=json["meta"]["lang"], |
| | get_random=get_random, |
| | use_task_prompt=use_task_prompt, |
| | ) |
| |
|
| | prompt = qa[0] |
| | answer = qa[-1] if reasoning else qa[1] |
| |
|
| | if first_turn: |
| | user["video_metas"].append({"lens": []}) |
| | user["content"] += "<|video|>" |
| | prompt = task_prompt.format(prompt) |
| |
|
| | if "entities" in json: |
| | user["video_metas"][0]["lens"] = json["entities"].get("00", []) |
| | if isinstance(video, dict): |
| | vid_src.append(video["00"]) |
| | else: |
| | vid_src.append(video) |
| | first_turn = False |
| |
|
| | user["video_urls"] = vid_src |
| | user["content"] += prompt |
| |
|
| | assistant = { |
| | "candidates": [], |
| | "content": answer, |
| | "content_type": "text", |
| | "debuggingInfo": {}, |
| | "meta": {}, |
| | "role": "assistant", |
| | "trainable_content": True, |
| | "trainable_role": True, |
| | "stop": False, |
| | "endofturn": True, |
| | } |
| | turns.append(user) |
| | turns.append(assistant) |
| |
|
| | else: |
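| | # Image (non-video) samples. In eval mode, duplicate questions are merged so that all
| | # of their answers become candidates, and only the first qa pair is kept per sample.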
| | if key.startswith("qa_pairs") or key.startswith("captions"): |
| | if self.mode != "train" and key.startswith("qa_pairs"): |
| | qas = dict() |
| | for qa in json[key]: |
| | q = qa[0] |
| | if q not in qas: |
| | qas[q] = list() |
| | for _i, _e in enumerate(qa[1:]): |
| | if len(qas[q]) <= _i: |
| | qas[q].append(list()) |
| | qas[q][_i].append(_e) |
| | json[key] = [ |
| | [ |
| | k, |
| | ] |
| | + v |
| | for k, v in qas.items() |
| | ] |
| |
|
| | if self.mode != "train": |
| | json[key] = json[key][:1] |
| |
|
| | for qa in json[key]: |
| | img_src = [] |
| | user = { |
| | "role": "user", |
| | "content_type": "text", |
| | "candidates": [], |
| | "trainable_role": False, |
| | "trainable_content": False, |
| | "stop": False, |
| | "debuggingInfo": {}, |
| | "meta": {}, |
| | "image_urls": [], |
| | "image_metas": [], |
| | "video_urls": [], |
| | "video_metas": [], |
| | "audio_urls": [], |
| | "audio_metas": [], |
| | "content": "", |
| | "endofturn": False, |
| | } |
| | img_keys = re.findall(r"<image_(\d+)>", qa[0]) |
| | video_keys = re.findall(r"<video_(\d+)>", qa[0]) |
| | audio_keys = re.findall(r"<audio_(\d+)>", qa[0]) |
| |
|
| | if key.startswith("qa_pairs"): |
| | if len(qa) > 2: |
| | sample_id = qa[2] |
| | if ( |
| | isinstance(sample_id, (list, tuple)) |
| | and len(sample_id) > 0 |
| | ): |
| | sample_id = sample_id[0] |
| | sample["vlm"]["sample_id"] = sample_id |
| |
|
| | instruct_prompt, task_prompt = hcx_vision_prompter( |
| | task=json["meta"]["type"], |
| | subtask=json["meta"].get("subtask", None), |
| | lang=json["meta"]["lang"], |
| | get_random=get_random, |
| | use_task_prompt=use_task_prompt, |
| | ) |
| | if json["meta"]["type"] == "vqa": |
| | prompt = qa[0] |
| | answer = qa[-1] if reasoning else qa[1] |
| | elif json["meta"]["type"] == "caption": |
| | prompt = task_prompt.format("") |
| | answer = qa |
| |
|
| | if first_turn or self.mode != "train": |
| | if json["meta"]["type"] == "vqa": |
| | prompt = task_prompt.format(prompt) |
| | if first_turn and not is_multi_image_dataset: |
| | user["image_metas"].append({"words": [], "lens": []}) |
| | if "<image_00>" in prompt: |
| | prompt = prompt.replace("<image_00>", "<|image|>") |
| | else: |
| | user["content"] += "<|image|>" |
| | user["image_metas"][0]["words"] = json.get("words", {}).get( |
| | "00", [] |
| | ) |
| | if "objects" in json: |
| | user["image_metas"][0]["lens"] = json["objects"].get( |
| | "00", [] |
| | ) |
| | elif "entities" in json: |
| | user["image_metas"][0]["lens"] = json["entities"].get( |
| | "00", [] |
| | ) |
| | if isinstance(img, dict): |
| | img_src.append(img["00"]) |
| | else: |
| | img_src.append(img) |
| | elif len(img_keys) > 0: |
| | # use a distinct loop variable so the outer `key` (the dataset field name) is not shadowed
| | for i, img_key in enumerate(img_keys):
| | user["image_metas"].append({"words": [], "lens": []})
| | if f"<image_{i:02d}>" in prompt:
| | prompt = prompt.replace(f"<image_{i:02d}>", "<|image|>")
| | else:
| | user["content"] += "<|image|>"
| | img_src.append(img[img_key])
| | _words = json.get("words", {})
| | if isinstance(_words, dict):
| | _words = _words.get(img_key, [])
| | user["image_metas"][i]["words"] = _words
| | if "objects" in json:
| | _objects = json["objects"].get(img_key, [])
| | if isinstance(_objects, dict):
| | _objects = _objects.get(img_key, [])
| | user["image_metas"][i]["lens"] = _objects
| | if "entities" in json:
| | _entities = json["entities"].get(img_key, [])
| | if isinstance(_entities, dict):
| | _entities = _entities.get(img_key, [])
| | user["image_metas"][i]["lens"] = _entities
| | user["image_urls"] = img_src |
| |
|
| | if len(audio_keys) > 0: |
| | for i, audio_key in enumerate(audio_keys):
| | if isinstance(audio, dict):
| | user["audio_urls"].append(audio[audio_key])
| | else:
| | user["audio_urls"].append(audio)
| | user["audio_metas"].append( |
| | { |
| | "format": "wav", |
| | "note": "This audio sample is passed to convert_wds_to_datalake function.", |
| | } |
| | ) |
| | if f"<audio_{i:02d}>" in prompt: |
| | prompt = prompt.replace(f"<audio_{i:02d}>", "<|audio|>") |
| | else: |
| | user["content"] += "<|audio|>" |
| |
|
| | user["content"] += prompt |
| |
|
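| | # Eval mode: normalize the ground truth into one reference string plus candidate
| | # strings, literal-eval'ing stringified lists where needed (e.g. textvqa) and
| | # preferring the multiple-choice answer from data_info when one is provided.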
| | content, candidates = None, list() |
| | if self.mode != "train": |
| | if isinstance(answer, (int, float)): |
| | pass |
| | elif isinstance(answer, str): |
| | if answer != "None": |
| | try: |
| | answer = ast.literal_eval(answer) |
| | except Exception as ex: |
| | pass |
| | if not isinstance(answer, (list, tuple)): |
| | answer = [ |
| | answer, |
| | ] |
| | candidates += answer[1:] |
| | answer = answer[0] |
| | content = answer |
| | elif isinstance(answer, (list, tuple)): |
| | for _idx, _answer in enumerate(answer): |
| | if isinstance(_answer, str): |
| | if isinstance(benchmark, str) and benchmark in [ |
| | "textvqa", |
| | ]: |
| | try: |
| | _answer = ast.literal_eval(_answer) |
| | except Exception as ex: |
| | pass |
| | if isinstance(_answer, dict): |
| | _answer = str(_answer) |
| | if not isinstance(_answer, (list, tuple)): |
| | _answer = [ |
| | _answer, |
| | ] |
| | if _idx == 0: |
| | content = _answer[0] |
| | candidates += _answer[1:] |
| | else: |
| | candidates += _answer |
| |
|
| | if isinstance(content, (int, float)): |
| | content = str(content) |
| | assert content is None or isinstance(content, str) |
| | for _idx, _candidate in enumerate(candidates): |
| | if isinstance(_candidate, (int, float)): |
| | candidates[_idx] = str(_candidate) |
| | assert isinstance(candidates[_idx], str) |
| | mcqa_gt = sample["vlm"]["data_info"].get("choice_answer", None) |
| | if isinstance(mcqa_gt, str): |
| | content = mcqa_gt |
| |
|
| | assistant = { |
| | "candidates": candidates, |
| | "content": answer if self.mode == "train" else content, |
| | "content_type": "text", |
| | "debuggingInfo": {}, |
| | "meta": {}, |
| | "role": "assistant", |
| | "trainable_content": True, |
| | "trainable_role": True, |
| | "stop": False, |
| | "endofturn": True, |
| | } |
| | turns.append(user) |
| | turns.append(assistant) |
| |
|
| | elif key == "words": |
| | img_src = [] |
| | user = { |
| | "role": "user", |
| | "content_type": "text", |
| | "candidates": [], |
| | "trainable_role": False, |
| | "trainable_content": False, |
| | "stop": False, |
| | "debuggingInfo": {}, |
| | "meta": {}, |
| | "image_urls": [], |
| | "image_metas": [], |
| | "video_urls": [], |
| | "video_metas": [], |
| | "audio_urls": [], |
| | "audio_metas": [], |
| | "content": "<|image|>", |
| | "endofturn": False, |
| | } |
| | instruct_prompt, task_prompt = hcx_vision_prompter( |
| | task=json["meta"]["type"], |
| | subtask=json["meta"].get("subtask", None), |
| | lang=json["meta"]["lang"], |
| | get_random=get_random, |
| | use_task_prompt=use_task_prompt, |
| | ) |
| | user["content"] += task_prompt |
| | user["image_metas"].append({"words": [], "lens": []}) |
| | user["image_metas"][0]["words"] = json["words"]["00"] |
| | if "entities" in json: |
| | user["image_metas"][0]["lens"] = json["entities"].get("00", []) |
| | img_src.append(img["00"]) |
| | user["image_urls"] = img_src |
| |
|
| | words_list = [ |
| | d["text"].strip() for d in json["words"]["00"] if d["text"] |
| | ] |
| | gt = " ".join(words_list) |
| | assistant = { |
| | "candidates": [], |
| | "content": gt, |
| | "content_type": "text", |
| | "debuggingInfo": {}, |
| | "meta": {}, |
| | "role": "assistant", |
| | "trainable_content": True, |
| | "trainable_role": True, |
| | "stop": False, |
| | "endofturn": True, |
| | } |
| | turns.append(user) |
| | turns.append(assistant) |
| |
|
| | elif key.startswith("regions"): |
| | for region in json[key]: |
| | img_src = [] |
| | user = { |
| | "role": "user", |
| | "content_type": "text", |
| | "candidates": [], |
| | "trainable_role": False, |
| | "trainable_content": False, |
| | "stop": False, |
| | "debuggingInfo": {}, |
| | "meta": {}, |
| | "image_urls": [], |
| | "image_metas": [], |
| | "video_urls": [], |
| | "video_metas": [], |
| | "audio_urls": [], |
| | "audio_metas": [], |
| | "content": "<|image|><|region|>", |
| | "endofturn": False, |
| | } |
| | instruct_prompt, task_prompt = hcx_vision_prompter( |
| | task=json["meta"]["type"], |
| | subtask=json["meta"].get("subtask", None), |
| | lang=json["meta"]["lang"], |
| | get_random=get_random, |
| | use_task_prompt=use_task_prompt, |
| | ) |
| | sample["vlm"]["query_template"] = [task_prompt] |
| | user["image_metas"].append({"words": [], "lens": []}) |
| | user["image_metas"][0]["region"] = region |
| | if "words" in json: |
| | user["image_metas"][0]["words"] = json["words"].get("00", []) |
| | if "objects" in json: |
| | user["image_metas"][0]["lens"] = json["objects"].get("00", []) |
| | if "entities" in json: |
| | user["image_metas"][0]["lens"] = json["entities"].get("00", []) |
| | img_src.append(img["00"]) |
| | user["image_urls"] = img_src |
| |
|
| | assistant = { |
| | "candidates": [], |
| | "content": region["text"], |
| | "content_type": "text", |
| | "debuggingInfo": {}, |
| | "meta": {}, |
| | "role": "assistant", |
| | "trainable_content": True, |
| | "trainable_role": True, |
| | "stop": False, |
| | "endofturn": True, |
| | } |
| | turns.append(user) |
| | turns.append(assistant) |
| | else: |
| | raise ConditionalError( |
| | f"wrong task type in wds config: {sample['vlm']['data_name']}" |
| | ) |
| | sample["data"] = turns |
| | return sample |
| |
|
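| | # Main per-sample preprocessing: resolve the effective config, optionally inject a
| | # random system prompt and random tools, render every turn into input/label ids,
| | # run the image/video/audio processors, and expand each multimodal placeholder token
| | # to its final query length.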
| | def preprocess_new(self, sample): |
| |
|
| | config = sample.get("vlm", {}) |
| | if config["data_type"] in ["sft1", "datalake"]: |
| | default_config = copy.deepcopy(self.default_config) |
| | default_config.update(config) |
| | config = default_config |
| | idx_for_debug = sample.get("idx", -1) |
| | turns = sample["data"] if "data" in sample else sample["messages"] |
| |
|
| | if self.random_system_prompt and self.rng.random() < config.get( |
| | "random_system_prob", 0.0 |
| | ): |
| | for turn in turns: |
| | if turn["role"] == "system": |
| | turn["content"] = self.random_system_prompt |
| | break |
| |
|
| | if sample.get("tools", None) is None: |
| | sample["tools"] = [] |
| |
|
| | if len(sample["tools"]) == 0: |
| | if ( |
| | self.rng.random() < config.get("random_tool_prob", 0.005) |
| | and len(self.common_tools) > 0 |
| | ): |
| |
|
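| | # With a small probability, attach a few unrelated "common" tools, presumably so the
| | # model also sees turns where the provided tools are irrelevant. Both the number of
| | # tools and the tool indices are sampled with roughly 1/n weights, favoring fewer
| | # tools and earlier entries in self.common_tools.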
| | max_n_tools = min(7, len(self.common_tools)) |
| | tool_counts = np.arange(1, max_n_tools + 1) |
| | tool_count_weights = 1.0 / tool_counts |
| | tool_count_weights = tool_count_weights / tool_count_weights.sum() |
| | n_tools = int(self.rng.choice(tool_counts, p=tool_count_weights)) |
| |
|
| | idxs = np.arange(len(self.common_tools)) |
| | weights = 1.0 / (idxs + 1) |
| | weights[0] += 1.0 |
| | weights = weights / weights.sum() |
| |
|
| | chosen_indices = self.rng.choice( |
| | len(self.common_tools), size=n_tools, replace=False, p=weights |
| | ) |
| |
|
| | self.rng.shuffle(chosen_indices) |
| |
|
| | sample["tools"] = [self.common_tools[i] for i in chosen_indices] |
| |
|
| | if "tools" in sample and sample["tools"]: |
| | tool_prompt = [] |
| | tool_prompt.append("# Tools\n\n") |
| | tool_prompt.append( |
| | "You may call one or more functions to assist with the user query.\n\n" |
| | ) |
| | tool_prompt.append( |
| | "You are provided with function signatures within <tools></tools> XML tags:\n" |
| | ) |
| | tool_prompt.append("<tools>\n") |
| | for tool in sample["tools"]: |
| | tool_prompt.append(json.dumps(tool, ensure_ascii=False)) |
| | tool_prompt.append("\n</tools>\n\n") |
| | tool_prompt.append( |
| | "For each function call, output the function name and arguments within the following XML format:\n" |
| | ) |
| | tool_prompt.append("<tool_call>{function-name}\n") |
| | tool_prompt.append("<arg_key>{arg-key-1}</arg_key>\n") |
| | tool_prompt.append("<arg_value>{arg-value-1}</arg_value>\n") |
| | tool_prompt.append("<arg_key>{arg-key-2}</arg_key>\n") |
| | tool_prompt.append("<arg_value>{arg-value-2}</arg_value>\n") |
| | tool_prompt.append("...\n") |
| | tool_prompt.append("</tool_call>") |
| |
|
| | tool_prompt = "".join(tool_prompt) |
| | else: |
| | tool_prompt = "" |
| |
|
| | multiturn_n_sample = config.get("multiturn_n_samples", 0) |
| | if multiturn_n_sample > 0 and self.mode == "train": |
| | turns = self._sampling_multiturn( |
| | turns, |
| | multiturn_n_sample, |
| | multiturn_preserve_order=config.get("multiturn_preserve_order", True), |
| | multiturn_continuous=config.get("multiturn_continuous", False), |
| | ) |
| |
|
| | for i, turn in enumerate(turns): |
| | if turn["role"] == "user": |
| | if "img_src" in turn: |
| | turns[i]["image_urls"] = turn["img_src"] |
| | turns[i]["image_metas"] = turn["meta"] |
| | for j, turn_img_meta in enumerate(turns[i]["image_metas"]): |
| | if "entities" in turn_img_meta: |
| | turns[i]["image_metas"][j]["lens"] = turn_img_meta[ |
| | "entities" |
| | ] |
| | turns[i]["meta"] = {} |
| |
|
| | max_image_cnt = config.get("max_image_cnt", 20) |
| | if max_image_cnt > 0 and config["data_type"] != "sft1": |
| | n_imgs = {} |
| | for i, turn in enumerate(turns): |
| | if turn["role"] == "user": |
| | n_imgs[i] = len(turn.get("image_urls", [])) |
| | assert ( |
| | n_imgs[i] <= max_image_cnt |
| | ), "skip sample if image_nums exceeds max_image_count per turn" |
| |
|
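| | # If the conversation exceeds max_image_cnt images in total, strip image placeholders
| | # (and the matching URLs) from the oldest user turns first; any turn that loses all of
| | # its images makes the following turns, up to the next assistant turn, non-trainable.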
| | if sum(n_imgs.values()) > max_image_cnt: |
| | img_count = 0 |
| | for k, v in reversed(list(n_imgs.items())): |
| | img_count += v |
| | if img_count > max_image_cnt: |
| | break |
| |
|
| | img_count = sum(n_imgs.values()) - max_image_cnt |
| |
|
| | for i in range(k + 1): |
| | if turns[i]["role"] == "user": |
| | turns[i]["content"], n_removed1 = re.subn( |
| | r"<image_\d{2}>", |
| | "", |
| | turns[i]["content"].strip(), |
| | count=img_count, |
| | ) |
| | img_count -= n_removed1 |
| | turns[i]["content"], n_removed2 = re.subn( |
| | r"<\|image\|>", |
| | "", |
| | turns[i]["content"].strip(), |
| | count=img_count, |
| | ) |
| | img_count -= n_removed2 |
| | n_removed_imgs = n_removed1 + n_removed2 |
| | turns[i]["image_urls"] = turns[i]["image_urls"][n_removed_imgs:] |
| |
|
| | if n_removed_imgs > 0 and len(turns[i]["image_urls"]) == 0: |
| | idx = i |
| | while True: |
| | idx += 1 |
| | turns[idx]["trainable_role"] = False |
| | turns[idx]["trainable_content"] = False |
| | if turns[idx]["role"] == "assistant": |
| | break |
| |
|
| | n_imgs_after = {} |
| | for i, turn in enumerate(turns): |
| | if turn["role"] == "user": |
| | n_imgs_after[i] = len(turn.get("image_urls", [])) |
| | assert sum(n_imgs_after.values()) > 0, "The vlm sample has no images left after trimming."
| |
|
| | n_mm_after = {} |
| | for i, turn in enumerate(turns): |
| | if turn["role"] == "user" or turn["role"] == "assistant": |
| | n_mm_after[i] = ( |
| | len(turn.get("image_urls", [])) |
| | + len(turn.get("video_urls", [])) |
| | + len(turn.get("audio_urls", [])) |
| | ) |
| | assert sum(n_mm_after.values()) > 0, "The omni sample has no multimodal inputs (image, video, or audio)."
| |
|
| | queries, gts = list(), list() |
| | output = Processed_sample( |
| | input_str="", |
| | input_ids=[], |
| | label_ids=[], |
| | imgs=[], |
| | discrete_imgs=[], |
| | videos=[], |
| | videos_duration=[], |
| | video_audios=[], |
| | audios=[], |
| | audios_duration=[], |
| | discrete_audios=[], |
| | sample_mm_counter={ |
| | "image": 0, |
| | "video": 0, |
| | "audio": 0, |
| | }, |
| | ) |
| | system_role_count = 0 |
| | last_user_idx = max( |
| | (i for i, d in enumerate(turns) if d.get("role") == "user"), default=-1 |
| | ) |
| | for i, turn in enumerate(turns): |
| | if turn["role"] == "tool_list": |
| | continue |
| |
|
| | elif turn["role"] == "system": |
| | if config.get("llava_pretrain", False): |
| | continue |
| | output = Preprocessor.prompt_system( |
| | turn=turn, |
| | output=output, |
| | tokenizer=self.tokenizer, |
| | seed=self.rng, |
| | tool_prompt=tool_prompt, |
| | system_role_count=system_role_count, |
| | ) |
| | system_role_count += 1 |
| |
|
| | elif turn["role"].startswith("user"): |
| | output = Preprocessor.load_mm( |
| | output=output, |
| | img_dir=config.get("img_dir", ""), |
| | turn=turn, |
| | prepare_input_fn=self.prepare_input_fn, |
| | max_image_cnt=max_image_cnt, |
| | video_max_num_frames=self.video_max_num_frames, |
| | video_max_pixels=self.video_max_pixels, |
| | use_audio=self.train_audio, |
| | ) |
| | output = Preprocessor.prompt_user( |
| | output=output, |
| | tokenizer=self.tokenizer, |
| | turn=turn, |
| | is_train=(self.mode == "train"),
| | fixed_mime=config.get("fixed_mime", False), |
| | mimes=self.mimes, |
| | query_template=config.get("query_template", None), |
| | config=config, |
| | seed=self.rng, |
| | ) |
| |
|
| | queries.append(turn["content"].replace("<|image|>", "").strip()) |
| | elif turn["role"].startswith("assistant"): |
| | output = Preprocessor.load_mm( |
| | output=output, |
| | img_dir=config.get("img_dir", ""), |
| | turn=turn, |
| | prepare_input_fn=self.prepare_input_fn, |
| | max_image_cnt=max_image_cnt, |
| | video_max_num_frames=self.video_max_num_frames, |
| | video_max_pixels=self.video_max_pixels, |
| | use_audio=self.train_audio, |
| | ) |
| |
|
| | is_after_last_user = i > last_user_idx |
| | is_first_assistant_after_last_user = False |
| | if is_after_last_user: |
| | is_first_assistant_after_last_user = all( |
| | turns[j]["role"] != "assistant" |
| | for j in range(last_user_idx + 1, i) |
| | ) |
| |
|
| | output = Preprocessor.prompt_assistant( |
| | output=output, |
| | tokenizer=self.tokenizer, |
| | turn=turn, |
| | is_last_turn=is_first_assistant_after_last_user, |
| | is_eval=(self.mode != "train"),
| | is_llava_pretrain=config.get("llava_pretrain", False), |
| | is_after_last_user_turn=is_after_last_user, |
| | ) |
| | _gts = turn["content"] |
| | if isinstance(_gts, str): |
| | _gts = [ |
| | _gts, |
| | ] |
| | if "candidates" in turn and len(turn["candidates"]) > 0: |
| | for _candidates in turn["candidates"]: |
| | if isinstance(_candidates, str): |
| | _gts += [ |
| | _candidates, |
| | ] |
| | elif isinstance(turn["candidates"][0], (list, tuple)): |
| | _gts += _candidates |
| | gts.append(_gts) |
| | elif turn["role"] == "tool": |
| | if config.get("llava_pretrain", False): |
| | continue |
| |
|
| | output = Preprocessor.prompt_tool( |
| | output=output, |
| | tokenizer=self.tokenizer, |
| | turn=turn, |
| | need_start_tag=(i == 0 or turns[i - 1].get("role") != "tool"),
| | need_end_tag=(
| | i == (len(turns) - 1) or turns[i + 1].get("role") != "tool"
| | ),
| | ) |
| | else: |
| | if config.get("llava_pretrain", False): |
| | continue |
| |
|
| | # Any other role falls through to the generic prompt_etc handler below.
| | output = Preprocessor.prompt_etc( |
| | output=output, |
| | tokenizer=self.tokenizer, |
| | turn=turn, |
| | ) |
| |
|
| | pixel_values = [] |
| | mm_query_lengths = [] |
| | discrete_pixel_values = [] |
| | image_ratios = [] |
| | discrete_image_query_lengths = [] |
| |
|
| | labels = output.label_ids |
| | input_ids = output.input_ids |
| | total_mm_query_length = 0 |
| |
|
| | is_sft1 = False |
| | if config["data_type"] == "sft1": |
| | if self.sequence_parallel_size > 1: |
| | if len(input_ids) % self.sequence_parallel_size != 0: |
| | input_ids += [self.tokenizer.pad_token_id] * ( |
| | self.sequence_parallel_size |
| | - (len(input_ids) % self.sequence_parallel_size) |
| | ) |
| | labels += [IGNORE_INDEX] * ( |
| | self.sequence_parallel_size |
| | - (len(labels) % self.sequence_parallel_size) |
| | ) |
| |
|
| | input_ids = input_ids[ |
| | : (len(input_ids) // self.sequence_parallel_size) |
| | * self.sequence_parallel_size |
| | ] |
| | labels = labels[ |
| | : (len(labels) // self.sequence_parallel_size) |
| | * self.sequence_parallel_size |
| | ] |
| |
|
| | input_ids = torch.tensor(input_ids[-self.decoder_max_length :]) |
| | labels = torch.tensor(labels[-self.decoder_max_length :]) |
| | is_sft1 = True |
| |
|
| | dummy_preprocess_results = self.prepare_input_fn.image_processor( |
| | Image.new("RGB", (224, 224), (0, 0, 0)) |
| | ) |
| | dummy_pixel_values = torch.from_numpy( |
| | np.concatenate([dummy_preprocess_results.pixel_values], axis=0) |
| | ) |
| | dummy_grid_thw = torch.from_numpy( |
| | np.concatenate([dummy_preprocess_results.image_grid_thw], axis=0) |
| | ) |
| |
|
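| | # Continuous image path: resize each image, run the image processor, and take the
| | # query length as the patch count divided by 4 (assumed 2x2 spatial merge in the
| | # vision tower).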
| | image_grid_thw = [] |
| | for img in output.imgs: |
| | w, h = img.size |
| |
|
| | img = self._resize_min_edge(img) |
| | preprocess_results = self.prepare_input_fn.image_processor([img]) |
| | pixel_values.append(preprocess_results.pixel_values) |
| | image_grid_thw.append(preprocess_results.image_grid_thw) |
| | mm_query_lengths.append(preprocess_results.pixel_values.shape[0] // 4) |
| |
|
| | if len(output.imgs) == 0: |
| | pixel_values = torch.zeros(0, 1176) |
| | image_grid_thw = torch.zeros(0, 3, dtype=torch.long) |
| | else: |
| | pixel_values = torch.from_numpy(np.concatenate(pixel_values, axis=0)) |
| | image_grid_thw = torch.from_numpy(np.concatenate(image_grid_thw, axis=0)) |
| |
|
| | for img in output.discrete_imgs: |
| | w, h = img.size |
| |
|
| | img_ratio = self._find_best_ratio_token([h, w]) |
| | image_ratios.append(img_ratio) |
| | discrete_pixel_value = img.resize((384, 384), Image.BICUBIC) |
| | discrete_pixel_tensor = to_tensor(discrete_pixel_value) |
| |
|
| | assert discrete_pixel_tensor.shape == ( |
| | 3, |
| | 384, |
| | 384, |
| | ), f"Unexpected discrete_pixel_tensor shape: {discrete_pixel_tensor.shape}" |
| | assert not torch.isnan( |
| | discrete_pixel_tensor |
| | ).any(), "discrete_pixel_tensor contains NaN" |
| | assert not torch.isinf( |
| | discrete_pixel_tensor |
| | ).any(), "discrete_pixel_tensor contains Inf" |
| | pixel_min = discrete_pixel_tensor.min().item() |
| | pixel_max = discrete_pixel_tensor.max().item() |
| | assert ( |
| | 0.0 <= pixel_min <= 1.0 and 0.0 <= pixel_max <= 1.0 |
| | ), f"discrete_pixel_tensor values out of range [0, 1]: min={pixel_min}, max={pixel_max}" |
| |
|
| | discrete_pixel_values.append(discrete_pixel_tensor) |
| | discrete_image_query_lengths.append(729) |
| |
|
| | if len(output.discrete_imgs) == 0: |
| | discrete_pixel_values = torch.zeros(0, 3, 384, 384) |
| | else: |
| | discrete_pixel_values = torch.stack(discrete_pixel_values, dim=0) |
| |
|
| | assert discrete_pixel_values.shape[1:] == ( |
| | 3, |
| | 384, |
| | 384, |
| | ), f"Unexpected stacked discrete_pixel_values shape: {discrete_pixel_values.shape}" |
| | assert not torch.isnan( |
| | discrete_pixel_values |
| | ).any(), "Stacked discrete_pixel_values contains NaN" |
| | assert not torch.isinf( |
| | discrete_pixel_values |
| | ).any(), "Stacked discrete_pixel_values contains Inf" |
| |
|
| | pixel_values_videos = None |
| | video_grid_thw = None |
| | if self.train_video: |
| | pixel_values_videos = [] |
| | video_grid_thw = [] |
| | video_query_lengths = [] |
| | for video in output.videos: |
| | preprocess_results = self.prepare_input_fn.video_processor([video]) |
| | pixel_values_videos.append(preprocess_results.pixel_values_videos) |
| | video_grid_thw.append(preprocess_results.video_grid_thw) |
| | video_query_lengths.append( |
| | preprocess_results.pixel_values_videos.shape[0] // 4 |
| | ) |
| | if len(output.videos) == 0: |
| | pixel_values_videos = torch.zeros(0, 1176) |
| | video_grid_thw = torch.zeros(0, 3, dtype=torch.long) |
| | else: |
| | pixel_values_videos = torch.from_numpy( |
| | np.concatenate(pixel_values_videos, axis=0) |
| | ) |
| | video_grid_thw = torch.from_numpy( |
| | np.concatenate(video_grid_thw, axis=0) |
| | ) |
| |
|
| | video_audio_values = [] |
| | video_audio_masks = [] |
| | video_audio_query_lengths = [] |
| | if self.train_video and hasattr(output, "video_audios") and output.video_audios: |
| | for idx, video_audio_chunks in enumerate(output.video_audios): |
| | if video_audio_chunks: |
| | processed_audio_values = [] |
| | processed_audio_masks = [] |
| | chunk_output_lengths = [] |
| |
|
| | for chunk in video_audio_chunks: |
| | if isinstance(chunk, torch.Tensor): |
| | chunk_np = chunk.cpu().numpy() |
| | else: |
| | chunk_np = chunk |
| |
|
| | preprocess_results = self.prepare_audio_input_fn( |
| | [chunk_np], |
| | sampling_rate=self.prepare_audio_input_fn.sampling_rate, |
| | return_attention_mask=True, |
| | padding="max_length", |
| | ) |
| |
|
| | audio_value = preprocess_results.input_features[0] |
| | audio_mask = preprocess_results.attention_mask[0] |
| |
|
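| | # Derive the encoder output length from the attention mask: the valid frames are
| | # downsampled by two stride-2 stages (assumed Whisper-style conv front end).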
| | mask_sum = int(audio_mask.sum()) |
| | input_lengths = (mask_sum - 1) // 2 + 1 |
| | output_lengths = (input_lengths - 2) // 2 + 1 |
| | chunk_output_lengths.append(output_lengths) |
| |
|
| | processed_audio_values.append(torch.from_numpy(audio_value)) |
| | processed_audio_masks.append(torch.from_numpy(audio_mask)) |
| |
|
| | pool_size = 25 |
| | if self.video_audio_compressor_type is not None: |
| | total_valid_len = sum(chunk_output_lengths) |
| | total_audio_query_length = ( |
| | total_valid_len + pool_size - 1 |
| | ) // pool_size |
| | else: |
| | total_audio_query_length = sum( |
| | (valid_len + pool_size - 1) // pool_size |
| | for valid_len in chunk_output_lengths |
| | ) |
| |
|
| | video_audio_values.append(processed_audio_values) |
| | video_audio_masks.append(processed_audio_masks) |
| | video_audio_query_lengths.append(total_audio_query_length) |
| |
|
| | if ( |
| | int(os.environ.get("RANK", -1)) == 0 |
| | and total_audio_query_length == 177 |
| | ): |
| | print( |
| | f"\n[PREPROCESSOR VIDEO - 177 TOKENS DETECTED!] total_audio_query_length={total_audio_query_length}, num_chunks={len(processed_audio_masks)}" |
| | ) |
| | for chunk_idx, mask_tensor in enumerate(processed_audio_masks): |
| | chunk_mask_sum = int(mask_tensor.sum()) |
| | chunk_input_len = (chunk_mask_sum - 1) // 2 + 1 |
| | chunk_output_len = (chunk_input_len - 2) // 2 + 1 |
| | chunk_pooled = (chunk_output_len + 24) // 25 |
| | print( |
| | f" Chunk {chunk_idx}: mask_sum={chunk_mask_sum}, output_len={chunk_output_len}, pooled={chunk_pooled}" |
| | ) |
| | print() |
| |
|
| | else: |
| | video_audio_values.append([]) |
| | video_audio_masks.append([]) |
| | video_audio_query_lengths.append(0) |
| |
|
| | dummy_video_preprocess_results = self.prepare_input_fn.video_processor( |
| | [Image.new("RGB", (224, 224), (0, 0, 0))] * 3 |
| | ) |
| | dummy_pixel_values_videos = torch.from_numpy( |
| | np.concatenate([dummy_video_preprocess_results.pixel_values_videos], axis=0) |
| | ) |
| | dummy_video_grid_thw = torch.from_numpy( |
| | np.concatenate([dummy_video_preprocess_results.video_grid_thw], axis=0) |
| | ) |
| | dummy_video_preprocess_results = self.prepare_audio_input_fn( |
| | [np.zeros(self.prepare_audio_input_fn.sampling_rate * 3, dtype=np.float32)], |
| | sampling_rate=self.prepare_audio_input_fn.sampling_rate, |
| | return_attention_mask=True, |
| | padding="max_length", |
| | ) |
| | dummy_video_audio_values = torch.from_numpy( |
| | dummy_video_preprocess_results.input_features |
| | ) |
| | dummy_video_audio_masks = torch.from_numpy( |
| | dummy_video_preprocess_results.attention_mask |
| | ) |
| |
|
| | audio_values = None |
| | discrete_audio_values = None |
| | audio_masks = None |
| | dummy_preprocess_results = self.prepare_audio_input_fn( |
| | [np.zeros(self.prepare_audio_input_fn.sampling_rate * 3, dtype=np.float32)], |
| | sampling_rate=self.prepare_audio_input_fn.sampling_rate, |
| | return_attention_mask=True, |
| | padding="max_length", |
| | ) |
| | dummy_audio_values = torch.from_numpy(dummy_preprocess_results.input_features) |
| | dummy_audio_masks = torch.from_numpy(dummy_preprocess_results.attention_mask) |
| | if self.train_audio: |
| | audio_values = [] |
| | discrete_audio_values = [] |
| | audio_masks = [] |
| | audio_query_lengths = [] |
| | discrete_audio_query_lengths = [] |
| |
|
| | if len(output.audios) > 99: |
| | raise ConditionalError( |
| | f"Too many audio segments in one sample: {len(output.audios)} audios." |
| | ) |
| |
|
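| | # Split each audio stream into 30 s chunks (the assumed feature-extractor window),
| | # extract features for all chunks at once, and estimate the query length from the
| | # combined attention mask via the same two stride-2 downsampling steps as above.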
| | for audio in output.audios: |
| | chunks = [] |
| | for i in range( |
| | 0, len(audio), 30 * self.prepare_audio_input_fn.sampling_rate |
| | ): |
| | chunks.append( |
| | audio[i : i + 30 * self.prepare_audio_input_fn.sampling_rate] |
| | ) |
| | num_of_chunks = len(chunks) |
| | preprocess_results = self.prepare_audio_input_fn( |
| | chunks, |
| | sampling_rate=self.prepare_audio_input_fn.sampling_rate, |
| | return_attention_mask=True, |
| | padding="max_length", |
| | ) |
| | audio_value = preprocess_results.input_features |
| | audio_mask = preprocess_results.attention_mask |
| | audio_values.append(audio_value) |
| | audio_masks.append(audio_mask) |
| | input_lengths = int(audio_mask.sum()) |
| | input_lengths = (input_lengths - 1) // 2 + 1 |
| | output_lengths = (input_lengths - 2) // 2 + 1 |
| | audio_query_lengths.append(output_lengths) |
| |
|
| | if len(output.audios) == 0: |
| | audio_values = torch.zeros(0, 128, 3000) |
| | audio_masks = torch.zeros(0, 3000) |
| | else: |
| | audio_values = torch.from_numpy(np.concatenate(audio_values, axis=0)) |
| | audio_masks = torch.from_numpy(np.concatenate(audio_masks, axis=0)) |
| |
|
| | for audio in output.discrete_audios: |
| | audio_length = len(audio) |
| |
|
| | assert audio_length >= MIN_DISCRETE_AUDIO_CHUNK_SAMPLES, ( |
| | f"discrete_audio is too short ({audio_length} samples < {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES}). " |
| | f"This will cause 0-dim/empty tensor in CosyVoice encoder. " |
| | f"Skip this sample." |
| | ) |
| |
|
| | max_audio_length = 600 * DEFAULT_SAMPLE_RATE |
| | audio_duration_sec = audio_length / DEFAULT_SAMPLE_RATE |
| | assert ( |
| | audio_length <= max_audio_length |
| | ), f"discrete_audio is too long ({audio_length} samples = {audio_duration_sec:.1f}s > 600s). " |
| |
|
| | assert not torch.isnan(audio).any(), ( |
| | f"discrete_audio contains NaN values! " |
| | f"This will cause CUDA illegal memory access. Skip this sample." |
| | ) |
| | assert not torch.isinf(audio).any(), ( |
| | f"discrete_audio contains Inf values! " |
| | f"This will cause CUDA illegal memory access. Skip this sample." |
| | ) |
| |
|
| | audio_min, audio_max = audio.min().item(), audio.max().item() |
| | assert -100.0 <= audio_min <= 100.0 and -100.0 <= audio_max <= 100.0, ( |
| | f"discrete_audio has extreme values (min={audio_min:.2f}, max={audio_max:.2f}). " |
| | f"Expected roughly [-1, 1] range. This indicates corrupted audio. Skip this sample." |
| | ) |
| |
|
| | discrete_audio_values.append(audio) |
| |
|
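| | # Discrete audio longer than 80 s is split into chunks; the number of speech codes per
| | # chunk is estimated from 160-sample mel frames followed by two stride-2 convolutions,
| | # which works out to roughly 25 codes per second for the CosyVoice-style tokenizer.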
| | if audio_length > 80 * DEFAULT_SAMPLE_RATE: |
| | chunk_size = 80 * DEFAULT_SAMPLE_RATE |
| |
|
| | total_code_len = 0 |
| |
|
| | for start in range(0, audio_length, chunk_size): |
| | end = min(start + chunk_size, audio_length) |
| |
|
| | if ( |
| | end < audio_length |
| | and audio_length - end < MIN_DISCRETE_AUDIO_CHUNK_SAMPLES |
| | ): |
| | end = audio_length |
| |
|
| | chunk_length = end - start |
| |
|
| | assert chunk_length >= MIN_DISCRETE_AUDIO_CHUNK_SAMPLES, ( |
| | f"chunk_length={chunk_length} < {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES}. This should never happen with our chunking logic. " |
| | f"audio_length={audio_length}, start={start}, end={end}. Skip this sample." |
| | ) |
| |
|
| | mel_len = chunk_length // 160 |
| |
|
| | assert mel_len > 0, ( |
| | f"mel_len={mel_len} is invalid (chunk_length={chunk_length}). " |
| | f"This will cause illegal memory access in AudioEncoder. Skip this sample." |
| | ) |
| |
|
| | after_conv1 = (mel_len + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 |
| | code_len = (after_conv1 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 |
| |
|
| | assert code_len > 0, ( |
| | f"code_len={code_len} is invalid (mel_len={mel_len}, after_conv1={after_conv1}). " |
| | f"This will cause illegal memory access. Skip this sample." |
| | ) |
| |
|
| | total_code_len += code_len |
| |
|
| | if end >= audio_length: |
| | break |
| |
|
| | assert total_code_len > 0, ( |
| | f"total_code_len={total_code_len} is invalid after processing all chunks. " |
| | f"audio_length={audio_length}. This should never happen. Skip this sample." |
| | ) |
| |
|
| | audio_duration_sec = audio_length / DEFAULT_SAMPLE_RATE |
| | max_expected_codes = int(audio_duration_sec * 25 * 1.1) |
| | assert total_code_len <= max_expected_codes, ( |
| | f"total_code_len={total_code_len} is suspiciously large (max_expected={max_expected_codes}). " |
| | f"audio_length={audio_length} ({audio_duration_sec:.1f}s). " |
| | f"Expected ~{int(audio_duration_sec * 25)} tokens (25 tokens/sec). " |
| | f"This indicates calculation error. Skip this sample." |
| | ) |
| |
|
| | discrete_audio_query_lengths.append(total_code_len) |
| | else: |
| | mel_len = audio_length // 160 |
| |
|
| | assert mel_len > 0, ( |
| | f"mel_len={mel_len} is invalid (audio_length={audio_length}). " |
| | f"This will cause illegal memory access in AudioEncoder. Skip this sample." |
| | ) |
| |
|
| | after_conv1 = (mel_len + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 |
| | code_len = (after_conv1 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 |
| |
|
| | assert code_len > 0, ( |
| | f"Calculated code_len={code_len} is invalid (audio_length={audio_length}, " |
| | f"mel_len={mel_len}, after_conv1={after_conv1}). " |
| | f"This indicates corrupted audio data. Skip this sample." |
| | ) |
| |
|
| | assert code_len <= 2048, ( |
| | f"code_len={code_len} exceeds freqs_cis max length (2048). " |
| | f"Audio length: {audio_length / DEFAULT_SAMPLE_RATE:.1f}s (max ~82s for single chunk). " |
| | f"Expected ~{int((audio_length / DEFAULT_SAMPLE_RATE) * 25)} tokens at 25 tokens/sec. " |
| | f"This will cause illegal memory access in apply_rotary_emb. Skip this sample." |
| | ) |
| |
|
| | discrete_audio_query_lengths.append(code_len) |
| |
|
| | img_start_ids = [ |
| | i for i, token in enumerate(input_ids) if token == self.img_token |
| | ] |
| | assert len(img_start_ids) == len(mm_query_lengths) |
| | for i, length in zip( |
| | range(len(mm_query_lengths) - 1, -1, -1), mm_query_lengths[::-1] |
| | ): |
| | labels[img_start_ids[i] : img_start_ids[i] + 1] = [IGNORE_INDEX] * length |
| | input_ids[img_start_ids[i] : img_start_ids[i] + 1] = [ |
| | self.img_token |
| | ] * length |
| | total_mm_query_length += length |
| |
|
| | discrete_image_start_ids = [ |
| | i for i, token in enumerate(input_ids) if token == self.discrete_image_token |
| | ] |
| | assert len(discrete_image_start_ids) == len(discrete_image_query_lengths) |
| | assert len(discrete_image_start_ids) == len( |
| | image_ratios |
| | ), "discrete_image_start_ids and image_ratios length mismatch" |
| |
|
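| | # Expand each discrete-image placeholder into its full token sequence: an aspect-ratio
| | # token, then 729 (27x27) discrete image tokens with an end-of-line token after every
| | # 27 tokens, and a final end-of-frame token. Labels are copied only if the placeholder
| | # position was trainable.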
| | for idx in range(len(discrete_image_query_lengths) - 1, -1, -1): |
| | i = discrete_image_start_ids[idx] |
| | length = discrete_image_query_lengths[idx] |
| | ratio_token_id = image_ratios[idx] |
| | assert ( |
| | length == 729 |
| | ), f"discrete_image_query_length must be 729, but got {length}" |
| |
|
| | token_sequence = [ratio_token_id] |
| | for token_idx in range(length): |
| | token_sequence.append(self.discrete_image_token) |
| | if (token_idx + 1) % 27 == 0: |
| | token_sequence.append(self.discrete_image_eol_token) |
| | token_sequence.append(self.discrete_image_eof_token) |
| |
|
| | total_length = len(token_sequence) |
| | if labels[i] == IGNORE_INDEX: |
| | labels[i : i + 1] = [IGNORE_INDEX] * total_length |
| | else: |
| | labels[i : i + 1] = token_sequence |
| | input_ids[i : i + 1] = token_sequence |
| |
|
| | if self.train_video: |
| | vid_start_ids = [ |
| | i for i, token in enumerate(input_ids) if token == self.video_token |
| | ] |
| |
|
| | for idx in range(len(vid_start_ids) - 1, -1, -1): |
| | pos = vid_start_ids[idx] |
| |
|
| | num_frames = int(video_grid_thw[idx][0]) |
| | frame_query_length = video_query_lengths[idx] |
| |
|
| | has_video_audio = ( |
| | idx < len(video_audio_query_lengths) |
| | and video_audio_query_lengths[idx] > 0 |
| | ) |
| |
|
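| | # When the clip has an audio track, interleave tokens frame by frame: each frame gets
| | # its fixed share of video tokens followed by a roughly even share of audio tokens,
| | # keeping the visual and audio features temporally aligned in the sequence.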
| | if has_video_audio: |
| | total_audio_tokens = video_audio_query_lengths[idx] |
| |
|
| | token_sequence = [] |
| |
|
| | if num_frames > 0: |
| |
|
| | frame_base = frame_query_length // num_frames |
| | frame_remainder = frame_query_length % num_frames |
| |
|
| | assert frame_remainder == 0, ( |
| | f"frame_query_length({frame_query_length}) must be divisible by num_frames({num_frames}). " |
| | f"Each frame produces fixed number of tokens. Got remainder={frame_remainder}." |
| | ) |
| |
|
| | audio_base = total_audio_tokens // num_frames |
| | audio_remainder = total_audio_tokens % num_frames |
| |
|
| | for frame_idx in range(num_frames): |
| | frame_tokens = frame_base + ( |
| | 1 if frame_idx < frame_remainder else 0 |
| | ) |
| | token_sequence.extend([self.video_token] * frame_tokens) |
| |
|
| | audio_tokens = audio_base + ( |
| | 1 if frame_idx < audio_remainder else 0 |
| | ) |
| | if audio_tokens > 0: |
| | token_sequence.extend( |
| | [self.video_audio_token] * audio_tokens |
| | ) |
| | else: |
| | token_sequence = [self.video_token] * frame_query_length |
| | else: |
| | token_sequence = [self.video_token] * frame_query_length |
| |
|
| | total_length = len(token_sequence) |
| | labels[pos : pos + 1] = [IGNORE_INDEX] * total_length |
| | input_ids[pos : pos + 1] = token_sequence |
| |
|
| | if self.train_audio: |
| | audio_start_ids = [ |
| | i for i, token in enumerate(input_ids) if token == self.audio_token |
| | ] |
| | assert len(audio_start_ids) == len(audio_query_lengths) |
| | for i, length in zip( |
| | range(len(audio_query_lengths) - 1, -1, -1), audio_query_lengths[::-1] |
| | ): |
| | labels[audio_start_ids[i] : audio_start_ids[i] + 1] = [ |
| | IGNORE_INDEX |
| | ] * length |
| | input_ids[audio_start_ids[i] : audio_start_ids[i] + 1] = [ |
| | self.audio_token |
| | ] * length |
| |
|
| | discrete_audio_start_ids = [ |
| | i |
| | for i, token in enumerate(input_ids) |
| | if token == self.discrete_audio_token |
| | ] |
| |
|
| | assert len(discrete_audio_start_ids) == len(discrete_audio_query_lengths), ( |
| | f"discrete_audio_start_ids count ({len(discrete_audio_start_ids)}) != " |
| | f"discrete_audio_query_lengths count ({len(discrete_audio_query_lengths)}). " |
| | f"This indicates a serious bug in preprocessor or corrupted data. Skip this sample." |
| | ) |
| |
|
| | for i, length in zip( |
| | range(len(discrete_audio_query_lengths) - 1, -1, -1), |
| | discrete_audio_query_lengths[::-1], |
| | ): |
| | assert 0 < length < 16000, ( |
| | f"discrete_audio_query_length={length} is out of valid range [1, 16000). " |
| | f"Expected max ~15,000 for 600s audio at 25 tokens/sec. " |
| | f"This can cause illegal memory access when creating embeddings. Skip this sample." |
| | ) |
| |
|
| | if labels[discrete_audio_start_ids[i]] == IGNORE_INDEX: |
| | labels[ |
| | discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 |
| | ] = [IGNORE_INDEX] * length |
| | else: |
| | labels[ |
| | discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 |
| | ] = [self.discrete_audio_token] * length |
| | input_ids[ |
| | discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 |
| | ] = [self.discrete_audio_token] * length |
| |
|
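| | # Pad input_ids/labels so the final sequence length is divisible by the
| | # sequence-parallel world size.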
| | if self.sequence_parallel_size > 1: |
| | if len(input_ids) % self.sequence_parallel_size != 0: |
| | input_ids += [self.tokenizer.pad_token_id] * ( |
| | self.sequence_parallel_size |
| | - (len(input_ids) % self.sequence_parallel_size) |
| | ) |
| | labels += [IGNORE_INDEX] * ( |
| | self.sequence_parallel_size |
| | - (len(labels) % self.sequence_parallel_size) |
| | ) |
| |
|
| | if not is_sft1: |
| | input_ids = torch.tensor(input_ids) |
| | labels = torch.tensor(labels) |
| |
|
| | if self.mode == "train": |
| | if self.sample_min_length is not None and self.sample_min_length > 0: |
| | assert ( |
| | len(labels) >= self.sample_min_length |
| | ), "The sample is too short: {} < {}".format( |
| | len(labels), self.sample_min_length |
| | ) |
| | assert ( |
| | len(labels) <= self.decoder_max_length |
| | ), "The sample exceeds decoder_max_len: {} > {}".format( |
| | len(labels), self.decoder_max_length |
| | ) |
| | assert len(input_ids) == len(labels) |
| |
|
| | if len(labels) < 30: |
| | raise ConditionalError( |
| | "The sample is too short: {}".format(len(labels)) |
| | ) |
| |
|
| | if torch.all(labels == IGNORE_INDEX): |
| | raise ConditionalError( |
| | "Labels contain only IGNORE_INDEX, no training targets available" |
| | ) |
| |
|
| | sample = { |
| | "pixel_values": pixel_values, |
| | "discrete_pixel_values": discrete_pixel_values, |
| | "idx_for_debug": idx_for_debug, |
| | "input_ids": input_ids, |
| | "labels": labels, |
| | "queries": queries if len(queries) > 0 else None, |
| | "gts": gts if len(gts) > 0 else None, |
| | "mm_query_lengths": mm_query_lengths, |
| | "non_mm_query_lengths": len(labels) - total_mm_query_length, |
| | "total_length": len(labels), |
| | "data_name": config["data_name"], |
| | "data_type": config["data_type"], |
| | "img_start_ids": img_start_ids, |
| | "prompt": output.input_str, |
| | "options": config.get("options", None), |
| | "image_grid_thw": image_grid_thw, |
| | "pixel_values_videos": pixel_values_videos, |
| | "video_grid_thw": video_grid_thw, |
| | "video_audio_values": ( |
| | video_audio_values if len(video_audio_values) > 0 else None |
| | ), |
| | "video_audio_masks": ( |
| | video_audio_masks if len(video_audio_masks) > 0 else None |
| | ), |
| | "audio_values": audio_values, |
| | "discrete_audio_values": discrete_audio_values, |
| | "audio_masks": audio_masks, |
| | "dummy_pixel_values": dummy_pixel_values, |
| | "dummy_grid_thw": dummy_grid_thw, |
| | "dummy_audio_values": dummy_audio_values, |
| | "dummy_audio_masks": dummy_audio_masks, |
| | "dummy_pixel_values_videos": dummy_pixel_values_videos, |
| | "dummy_video_grid_thw": dummy_video_grid_thw, |
| | "dummy_video_audio_values": dummy_video_audio_values, |
| | "dummy_video_audio_masks": dummy_video_audio_masks, |
| | } |
| |
|
| | return sample |
| |
|
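| | # Sub-sample a long multi-turn conversation: system/tool_list turns are always kept,
| | # the remaining turns are grouped into segments that each begin at a user turn holding
| | # an image placeholder, and n_sample segments are then drawn contiguously, in order,
| | # or fully at random depending on the flags.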
| | def _sampling_multiturn( |
| | self, |
| | turns, |
| | n_sample, |
| | multiturn_preserve_order=True, |
| | multiturn_continuous=False, |
| | ): |
| | new_turns = [] |
| | sample_indices = [] |
| | first_user_turn = True |
| | start_idx = 0 |
| | for idx, turn in enumerate(turns): |
| | if turn["role"] in ["system", "tool_list"]: |
| | new_turns.append(turn) |
| | start_idx = idx + 1 |
| | continue |
| | if turn["role"] == "user": |
| | image_nums = re.findall(r"<image_(\d+)>", turn["content"]) |
| | if len(image_nums) == 0: |
| | image_nums = re.findall(r"<\|image\|>", turn["content"]) |
| | if len(image_nums) > 0: |
| | if first_user_turn: |
| | first_user_turn = False |
| | continue |
| | sample_indices.append([i for i in range(start_idx, idx)]) |
| | start_idx = idx |
| | sample_indices.append([i for i in range(start_idx, idx + 1)]) |
| | n_sample = min(n_sample, len(sample_indices)) |
| | if multiturn_continuous: |
| | start_index = random.randint(0, len(sample_indices) - n_sample) |
| | indices = range(start_index, start_index + n_sample) |
| | elif multiturn_preserve_order: |
| | indices = sorted(random.sample(range(len(sample_indices)), n_sample)) |
| | else: |
| | indices = random.sample(range(len(sample_indices)), n_sample) |
| | sampled_indices = [sample_indices[i] for i in indices] |
| | new_turns = new_turns + [ |
| | turns[i] for sampled_turns in sampled_indices for i in sampled_turns |
| | ] |
| | return new_turns |
| |
|