import copy
import os
from dataclasses import dataclass, field
from typing import Dict
import torch
import transformers
import ujson as json
from torch.utils.data import Dataset
from qwen_vl_utils import process_vision_info
from PIL import Image
import re
import yaml
import random
import math
import pprint

from .params import DataArguments
from .constants import *


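# Trim an (input_ids, labels) pair to max_length, then re-append the EOS token
# that the cut removed so truncated examples still end properly.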
def truncate_sequence(input_ids, labels, max_length, eos_token_id):
    if input_ids.size(0) > max_length:
        input_ids = input_ids[:max_length-1]
        labels = labels[:max_length-1]

        if eos_token_id is not None:
            input_ids = torch.cat([input_ids, torch.tensor([eos_token_id])])
            labels = torch.cat([labels, torch.tensor([eos_token_id])])

    return input_ids, labels


def pad_sequence(sequences, padding_side='right', padding_value=0):
    """
    Pad a list of sequences to the same length.
    sequences: list of tensors in [seq_len, *] shape
    """
    assert padding_side in ['right', 'left']
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output[i, :length] = seq
        else:
            output[i, -length:] = seq
    return output


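# Wrapper around qwen_vl_utils.process_vision_info: builds a one-turn message for a
# single image and returns the preprocessed image it produces.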
def get_image_info(image_path, min_pixel, max_pixel):
    # qwen_vl_utils reads the pixel budget from the plural "min_pixels"/"max_pixels" keys.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_path,
                    "min_pixels": min_pixel,
                    "max_pixels": max_pixel,
                }
            ],
        }
    ]

    image_input, _ = process_vision_info(messages)

    return image_input[0]


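# Same idea for videos; return_video_kwargs=True also gives back the sampling kwargs
# (e.g. fps) that the Qwen2.5-VL processor expects to receive again at encode time.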
def get_video_info(video_path, min_pixels, max_pixels, fps):
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": video_path,
                    "min_pixels": min_pixels,
                    "max_pixels": max_pixels,
                    "fps": fps,
                }
            ],
        }
    ]

    _, video_input, video_kwargs = process_vision_info(messages, return_video_kwargs=True)

    return video_input[0], video_kwargs


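# data_path may be a .json file, a .yaml file listing several datasets, or an
# already-loaded list of samples.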
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path: str | list,
        processor: transformers.ProcessorMixin,
        data_args: DataArguments,
        model_id,
        padding=True,
    ):
        super().__init__()
        if isinstance(data_path, str):
            if data_path.endswith(".json"):
                with open(data_path, "r") as file:
                    list_data_dict = json.load(file)
            elif data_path.endswith(".yaml"):
                list_data_dict = []
                with open(data_path, "r") as file:
                    yaml_data = yaml.safe_load(file)
                pprint.pprint(yaml_data)
                datasets = yaml_data.get("datasets")
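                # Each YAML entry provides a json/jsonl path plus an optional sampling
                # strategy: "all" (default), "first:N", "end:N" or "random:N",
                # where N may also be a percentage such as "random:10%".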
                data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
                for dataset in datasets:
                    json_path = dataset.get("json_path")
                    sampling_strategy = dataset.get("sampling_strategy", "all")
                    sampling_number = None

                    print(f"Loading {json_path} with {sampling_strategy} sampling strategy")

                    if json_path.endswith(".jsonl"):
                        cur_data_dict = []
                        with open(json_path, "r") as json_file:
                            for line in json_file:
                                cur_data_dict.append(json.loads(line.strip()))
                    elif json_path.endswith(".json"):
                        with open(json_path, "r") as json_file:
                            cur_data_dict = json.load(json_file)
                    else:
                        raise ValueError(f"Unsupported file type: {json_path}")

                    if ":" in sampling_strategy:
                        sampling_strategy, sampling_number = sampling_strategy.split(":")
                        if "%" in sampling_number:
                            sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
                        else:
                            sampling_number = int(sampling_number)

                    if sampling_strategy == "first" and sampling_number is not None:
                        cur_data_dict = cur_data_dict[:sampling_number]
                    elif sampling_strategy == "end" and sampling_number is not None:
                        cur_data_dict = cur_data_dict[-sampling_number:]
                    elif sampling_strategy == "random" and sampling_number is not None:
                        random.shuffle(cur_data_dict)
                        cur_data_dict = cur_data_dict[:sampling_number]

                    print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
                    list_data_dict.extend(cur_data_dict)
                print(f"Loaded {len(list_data_dict)} samples from {data_path} in total")
        else:
            list_data_dict = data_path

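        # Cache the processor and the pixel/fps budgets from DataArguments for __getitem__.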
        self.model_id = model_id
        self.processor = processor
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        self.padding = padding
        self.image_min_pixel = data_args.image_min_pixels
        self.image_max_pixel = data_args.image_max_pixels
        self.video_min_pixel = data_args.video_min_pixels
        self.video_max_pixel = data_args.video_max_pixels
        self.fps = data_args.fps

    def __len__(self):
        return len(self.list_data_dict)

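    # Build one training example: load the vision inputs, convert the LLaVA-style
    # conversation to Qwen chat format, tokenize it turn by turn, and mask everything
    # except the assistant responses in the labels.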
    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]

        is_video = False

        processor = self.processor
        if "image" in sources:
            videos = None
            grid_key = "image_grid_thw"
            pixel_key = "pixel_values"

            image_files = sources["image"]
            image_folder = self.data_args.image_folder

            if isinstance(image_files, str):
                image_files = [image_files]

            images = []
            for image_file in image_files:
                if not os.path.exists(image_file):
                    if not image_file.startswith("http"):
                        if 'share/world_model/' in image_file:
                            # Strip the absolute prefix so the path can be re-rooted at image_folder.
                            image_file = image_file.split('share/world_model/')[1]
                        image_file = os.path.join(image_folder, image_file)
                images.append(get_image_info(image_file, self.image_min_pixel, self.image_max_pixel))
| | elif "video" in sources: |
| | is_video = True |
| | images=None |
| | grid_key = "video_grid_thw" |
| | pixel_key = "pixel_values_videos" |
| |
|
| | video_files = sources["video"] |
| | video_folder = self.data_args.image_folder |
| |
|
| | if isinstance(video_files, str): |
| | video_files = [video_files] |
| |
|
| | videos = [] |
| | for video_file in video_files: |
| | if not os.path.exists(video_file): |
| | if not video_file.startswith("http"): |
| | if 'share' in video_file: |
| | video_file = video_file.split('share/world_model/')[1] |
| | video_file = os.path.join(video_folder, video_file) |
| | video_input, video_kwargs = get_video_info(video_file, self.video_min_pixel, self.video_max_pixel, self.data_args.fps) |
| | videos.append(video_input) |
| | else: |
| | grid_key = None |
| | pixel_key = None |
| | images=None |
| | videos=None |
| |
|
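        # Convert the LLaVA-style 'conversations' list into role/content messages
        # carrying Qwen vision tokens.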
        sources = copy.deepcopy(llava_to_openai(sources['conversations'], is_video=is_video))

        all_input_ids = []
        all_labels = []
        all_pixel_values = []
        all_image_grid_thw = []
        all_second_grid = []

        if len(SYSTEM_MESSAGE) > 0:
            system_message = f"{DEFAULT_IM_START_TOKEN}system\n{SYSTEM_MESSAGE}\n{DEFAULT_IM_END_TOKEN}\n"
            system_message_input_ids = processor.tokenizer(system_message, add_special_tokens=False, return_tensors='pt')['input_ids']
            system_labels = torch.full_like(system_message_input_ids, IGNORE_INDEX)

            all_input_ids.append(system_message_input_ids.squeeze(0))
            all_labels.append(system_labels.squeeze(0))

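        # Iterate over (user, assistant) pairs; prompt tokens get IGNORE_INDEX labels so
        # only assistant tokens contribute to the loss.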
        for j in range(0, len(sources), 2):
            user_input = sources[j]
            gpt_response = sources[j + 1]

            user_input = f"{DEFAULT_IM_START_TOKEN}{user_input['role']}\n{user_input['content']}\n{DEFAULT_IM_END_TOKEN}\n{DEFAULT_IM_START_TOKEN}{gpt_response['role']}\n"
            gpt_response = f"{gpt_response['content']}\n{DEFAULT_IM_END_TOKEN}\n"

            if DEFAULT_IMAGE_TOKEN in user_input:
                inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])
            elif DEFAULT_VIDEO_TOKEN in user_input:
                if "Qwen2.5" in self.model_id:
                    # The Qwen2.5-VL processor uses the fps info from video_kwargs and emits second_per_grid_ts.
                    inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt', **video_kwargs)
                    all_second_grid.extend(inputs["second_per_grid_ts"])
                else:
                    inputs = processor(text=[user_input], images=images, videos=videos, padding=False, return_tensors='pt')
                prompt_input_ids = inputs['input_ids']
                all_pixel_values.append(inputs[pixel_key])
                all_image_grid_thw.append(inputs[grid_key])
            else:
                prompt_input_ids = processor.tokenizer(user_input, add_special_tokens=False, padding=False, return_tensors='pt')['input_ids']

            response_input_ids = processor.tokenizer(gpt_response, add_special_tokens=False, padding=False, return_tensors='pt')['input_ids']

            input_ids = torch.cat([prompt_input_ids, response_input_ids], dim=1).squeeze(0)
            labels = torch.cat(
                [
                    torch.tensor([IGNORE_INDEX] * len(prompt_input_ids[0])),
                    response_input_ids.squeeze(0),
                ],
                dim=0,
            )

            all_input_ids.append(input_ids)
            all_labels.append(labels)

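        # Flatten the per-turn chunks into one sequence. Padding happens later in the
        # collator, so the attention mask for a single example is all ones.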
        input_ids = torch.cat(all_input_ids, dim=0).to(torch.long)
        labels = torch.cat(all_labels, dim=0).to(torch.long)

        attention_mask = (input_ids > -1000000).to(torch.long)

        data_dict = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        if pixel_key and grid_key:
            pixel_values = torch.cat(all_pixel_values, dim=0)
            image_thw = torch.cat(all_image_grid_thw, dim=0)
            data_dict[pixel_key] = pixel_values
            data_dict[grid_key] = image_thw

        if len(all_second_grid) > 0:
            data_dict["second_per_grid_ts"] = all_second_grid

        return data_dict


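# The collator pads input_ids/labels to the batch max length and concatenates the
# flattened vision tensors (plus their grid_thw metadata) along the first dimension.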
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    def __init__(self, pad_token_id: int):
        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        batch_input_ids = []
        batch_label_ids = []
        batch_pixel_values = []
        batch_pixel_video_values = []
        batch_video_thw = []
        batch_image_thw = []
        batch_second_per_grid_ts = []

        for example in examples:
            keys = example.keys()
            if "pixel_values_videos" in keys:
                batch_pixel_video_values.append(example["pixel_values_videos"])
                batch_video_thw.append(example["video_grid_thw"])
            elif "pixel_values" in keys:
                batch_pixel_values.append(example["pixel_values"])
                batch_image_thw.append(example["image_grid_thw"])

            batch_input_ids.append(example["input_ids"])
            batch_label_ids.append(example["labels"])

            if "second_per_grid_ts" in keys:
                batch_second_per_grid_ts.extend(example["second_per_grid_ts"])

        input_ids = pad_sequence(
            batch_input_ids, padding_side='right', padding_value=self.pad_token_id
        )

        attention_mask = input_ids != self.pad_token_id
        labels = pad_sequence(batch_label_ids, padding_side='right', padding_value=IGNORE_INDEX)

        data_dict = {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
        }

        if len(batch_pixel_values) > 0:
            pixel_values = torch.cat(batch_pixel_values, dim=0)
            image_thw = torch.cat(batch_image_thw, dim=0)
            data_dict["pixel_values"] = pixel_values
            data_dict["image_grid_thw"] = image_thw

        if len(batch_pixel_video_values) > 0:
            pixel_video_values = torch.cat(batch_pixel_video_values, dim=0)
            video_thw = torch.cat(batch_video_thw, dim=0)
            data_dict["pixel_values_videos"] = pixel_video_values
            data_dict["video_grid_thw"] = video_thw

        if len(batch_second_per_grid_ts) > 0:
            data_dict["second_per_grid_ts"] = batch_second_per_grid_ts

        return data_dict


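# Swap the LLaVA placeholders (LLAVA_IMAGE_TOKEN / LLAVA_VIDEO_TOKEN) for Qwen's
# vision start/pad/end token triplet, absorbing any surrounding newlines.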
def replace_image_tokens(input_string, is_video=False):
    if is_video:
        pattern = r'\n?' + re.escape(LLAVA_VIDEO_TOKEN) + r'\n?'
        replacement = VISION_START_TOKEN + DEFAULT_VIDEO_TOKEN + VISION_END_TOKEN
    else:
        pattern = r'\n?' + re.escape(LLAVA_IMAGE_TOKEN) + r'\n?'
        replacement = VISION_START_TOKEN + DEFAULT_IMAGE_TOKEN + VISION_END_TOKEN

    return re.sub(pattern, replacement, input_string)


def llava_to_openai(conversations, is_video=False):
    role_mapping = {"human": "user", "gpt": "assistant"}

    transformed_data = []
    for conversation in conversations:
        transformed_content = replace_image_tokens(conversation["value"], is_video=is_video)
        transformed_entry = {
            "role": role_mapping.get(conversation["from"], conversation["from"]),
            "content": transformed_content,
        }
        transformed_data.append(transformed_entry)

    return transformed_data


def make_supervised_data_module(model_id, processor, data_args):
    """Make dataset and collator for supervised fine-tuning."""
    sft_dataset = SupervisedDataset(
        data_path=data_args.data_path, processor=processor, data_args=data_args, model_id=model_id
    )
    data_collator = DataCollatorForSupervisedDataset(pad_token_id=processor.tokenizer.pad_token_id)

    return dict(train_dataset=sft_dataset,
                eval_dataset=None,
                data_collator=data_collator)