import ast
import copy
import json
import math
import os
import random
import re
from typing import Dict
|
|
| import torch |
| import transformers |
| import yaml |
| from qwen_vl_utils import smart_resize, process_vision_info |
| from torch.utils.data import Dataset |
|
|
| from gui_actor.constants import ( |
| IGNORE_INDEX, |
| DEFAULT_IMAGE_TOKEN, |
| DEFAULT_POINTER_START_TOKEN, |
| DEFAULT_POINTER_PAD_TOKEN, |
| DEFAULT_POINTER_END_TOKEN, |
| ACTION_PATTENS_XY, |
| ADDITIONAL_SPECIAL_TOKENS, |
| assistant_template, |
| chat_template, |
| grounding_system_message, |
| ) |
| from gui_actor.trainer import rank0_print |
|
|
| def reformat_coordinates(text): |
| """ |
| (1) Find all the coordinates in the text. |
| (2) Replace the coordinates with the special tokens. |
    (3) Return the new text and the coordinates as a list of (x, y) pairs, where x and y are normalized to [0, 1].
| """ |
| epsilon = 0.001 |
| def adjust_coord(c): |
| """ |
| Adjust coordinate if it is too close to 0 or 1. |
| """ |
| if abs(c) < epsilon: |
| return epsilon |
| elif abs(c - 1) < epsilon: |
| return 1 - epsilon |
| return c |
|
|
    # Collect all coordinate matches, then replace each occurrence with pointer
    # special tokens: one triple for a single point, two for a point pair.
    all_matches = []
    for pattern in ACTION_PATTENS_XY:
        matches = list(re.finditer(pattern, text))
        for match in matches:
            all_matches.append((match.start(), match.groups()))
        if pattern == ACTION_PATTENS_XY[0]:
            target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
        else:
            target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}, {DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
        text = re.sub(pattern, target_text, text)
| |
    # Sort by match position so coordinates follow their order of appearance.
    coordinates = []
    all_matches.sort(key=lambda x: x[0])
| |
    for _, groups in all_matches:
        if len(groups) == 2:
            # A single point: (x, y).
            x_str, y_str = groups
            x = adjust_coord(ast.literal_eval(x_str))
            y = adjust_coord(ast.literal_eval(y_str))
            coordinates.append((x, y))
        elif len(groups) == 4:
            # Two points, e.g. a box or drag: (x1, y1) and (x2, y2).
            x1_str, y1_str, x2_str, y2_str = groups
            x1 = adjust_coord(ast.literal_eval(x1_str))
            y1 = adjust_coord(ast.literal_eval(y1_str))
            x2 = adjust_coord(ast.literal_eval(x2_str))
            y2 = adjust_coord(ast.literal_eval(y2_str))
            coordinates.append((x1, y1))
            coordinates.append((x2, y2))

    return text, coordinates
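

# Illustrative example (a sketch: the exact textual patterns matched are defined
# by ACTION_PATTENS_XY, and <start><pad><end> stands in for the concatenated
# DEFAULT_POINTER_* token strings):
#
#   text, coords = reformat_coordinates("pyautogui.click(x=0.5, y=0.25)")
#   # text   -> "pyautogui.click(<start><pad><end>)"
#   # coords -> [(0.5, 0.25)]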
|
|
| def get_token_index(image_processor, image, point_x, point_y): |
| """ |
| Get the index of the visual token that contains the point (x, y). |
| Args: |
| image_processor: the image processor |
| image: the image in PIL format |
| point_x: the x coordinate of the point, in [0, 1]. |
| point_y: the y coordinate of the point, in [0, 1]. |
| """ |
| if len(image) != 1: |
| raise ValueError(f"Expected 1 image, got {len(image)}") |
| |
| |
    image = image[0]
    w, h = image.size
    px, py = w * point_x, h * point_y

    # Each visual token covers a square of patch_size * merge_size pixels.
    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    x_index = math.floor(px / merge_patch_size)
    y_index = math.floor(py / merge_patch_size)

    # Row-major index into the grid of merged patches.
    visual_token_index = y_index * (w // merge_patch_size) + x_index

    return visual_token_index
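

# Worked example (assuming Qwen2-VL defaults of patch_size=14 and merge_size=2,
# i.e. a 28-pixel merged patch): a 1120x840 image yields a 40x30 token grid, and
# the point (0.5, 0.5) falls at pixel (560, 420), grid cell (x=20, y=15), hence
# visual token index 15 * 40 + 20 = 620.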
|
|
| def get_multi_patch_labels(image_processor, image, bbox_gt): |
| """ |
| Get the multi-patch labels for the bounding box. |
| Args: |
| image_processor: the image processor |
| image: the image in PIL format |
| bbox_gt: the bounding box in the format of (x_min, y_min, x_max, y_max) [0,1] |
| """ |
| if len(image) != 1: |
| raise ValueError(f"Expected 1 image, got {len(image)}") |
|
|
| |
| image = image[0] |
| w, h = image.size |
|
|
    # Scale the normalized box to pixel coordinates and clamp to image bounds.
    bbox_gt = [bbox_gt[0] * w, bbox_gt[1] * h, bbox_gt[2] * w, bbox_gt[3] * h]
    x_min, y_min, x_max, y_max = bbox_gt
    x_min = max(0, x_min)
    y_min = max(0, y_min)
    x_max = min(w, x_max)
    y_max = min(h, y_max)
|
|
| merge_patch_size = image_processor.patch_size * image_processor.merge_size |
| assert w % merge_patch_size == 0 and h % merge_patch_size == 0, f"Image size {w}x{h} is not divisible by merge_patch_size {merge_patch_size}" |
| grid_h, grid_w = h // merge_patch_size, w // merge_patch_size |
|
|
    binary_mask = torch.zeros(grid_h * grid_w)

    # Mark every merged patch whose pixel extent overlaps the box.
    for y_idx in range(grid_h):
        for x_idx in range(grid_w):
            patch_x_min = x_idx * merge_patch_size
            patch_y_min = y_idx * merge_patch_size
            patch_x_max = patch_x_min + merge_patch_size
            patch_y_max = patch_y_min + merge_patch_size

            # Overlap test: the patch misses the box only if it lies entirely
            # to one side of it.
            if not (patch_x_max <= x_min or patch_x_min >= x_max or
                    patch_y_max <= y_min or patch_y_min >= y_max):
                patch_idx = y_idx * grid_w + x_idx
                binary_mask[patch_idx] = 1
|
|
| return binary_mask |
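

# Worked example (same 28-pixel merged patch on a 1120x840 image, a 40x30 grid):
# bbox_gt = (0.0, 0.0, 0.05, 0.05) scales to pixels (0, 0, 56, 42) and overlaps
# the four cells with x in {0, 1} and y in {0, 1}, so the mask is 1 at flat
# indices 0, 1, 40, and 41.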
|
|
def token_index_to_coordinates(image_processor, visual_token_index, image_width, image_height):
    """
    Inverse of get_token_index: map a visual token index back to the pixel
    coordinates of the center of its merged patch.
    """
    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    x_index = visual_token_index % (image_width // merge_patch_size)
    y_index = visual_token_index // (image_width // merge_patch_size)
    px = x_index * merge_patch_size + merge_patch_size / 2
    py = y_index * merge_patch_size + merge_patch_size / 2
    return px, py
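

# Worked example (inverse of the one above): with a 28-pixel merged patch on a
# 1120-pixel-wide image (40 tokens per row), token 620 maps to grid cell
# (x=20, y=15), whose pixel center is (574.0, 434.0). image_height is accepted
# but unused here.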
|
|
| class LazySupervisedDataset(Dataset): |
| def __init__( |
| self, |
| tokenizer: transformers.PreTrainedTokenizer, |
| processor: transformers.ProcessorMixin, |
| data_path: str, |
| data_args, |
| ): |
| super().__init__() |
| self.tokenizer = tokenizer |
| self.processor = processor |
| self.list_data_dict = [] |
| self.list_image_path = [] |
        # Each pointer special token encodes to exactly one token id.
        self.pointer_pad_token_id = tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0]
        self.pointer_start_token_id = tokenizer.encode(DEFAULT_POINTER_START_TOKEN)[0]
        self.pointer_end_token_id = tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
|
|
| |
| if "{" in data_path and "}" in data_path: |
| base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups() |
| file_names = file_pattern.split(",") |
| rank0_print(f"Loading {file_names} from {base_path}") |
| data_args.dataset_paths = [] |
| for file_name in file_names: |
                full_path = f"{base_path}{file_name}.json"
                data_args.dataset_paths.append(full_path)
| rank0_print(f"Loading {full_path}") |
| with open(full_path) as file: |
| cur_data_dict = json.load(file) |
| rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}") |
                self.list_data_dict.extend(cur_data_dict)
                # Keep list_image_path aligned with list_data_dict; paths resolve
                # relative to data_args.image_folder.
                self.list_image_path.extend([""] * len(cur_data_dict))
| elif data_path.endswith(".yaml"): |
| with open(data_path) as file: |
| yaml_data = yaml.safe_load(file) |
| datasets = yaml_data.get("datasets") |
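            # Expected YAML layout (an illustrative sketch; the field names follow
            # the .get() calls below):
            #
            #   datasets:
            #     - json_path: /path/to/annotations.json
            #       sampling_strategy: first:1000  # "all", "first:N", "end:N", "random:N", or "strategy:P%"
            #       images_folder: /path/to/images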
| data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets] |
| for dataset in datasets: |
| json_path = dataset.get("json_path") |
| sampling_strategy = dataset.get("sampling_strategy", "all") |
| images_folder = dataset.get("images_folder") |
| sampling_number = None |
|
|
| rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy") |
|
|
| if json_path.endswith(".jsonl"): |
| cur_data_dict = [] |
| with open(json_path) as json_file: |
| for line in json_file: |
| cur_data_dict.append(json.loads(line.strip())) |
| elif json_path.endswith(".json"): |
| |
| |
| with open(json_path) as json_file: |
| cur_data_dict = json.load(json_file) |
| else: |
| raise ValueError(f"Unsupported file type: {json_path}") |
|
|
| if ":" in sampling_strategy: |
| sampling_strategy, sampling_number = sampling_strategy.split(":") |
| if "%" in sampling_number: |
| sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100) |
| else: |
| sampling_number = int(sampling_number) |
|
|
| |
| if sampling_strategy == "first" and sampling_number is not None: |
| cur_data_dict = cur_data_dict[:sampling_number] |
| elif sampling_strategy == "end" and sampling_number is not None: |
| cur_data_dict = cur_data_dict[-sampling_number:] |
| elif sampling_strategy == "random" and sampling_number is not None: |
| random.shuffle(cur_data_dict) |
| cur_data_dict = cur_data_dict[:sampling_number] |
|
|
| rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}") |
| self.list_data_dict.extend(cur_data_dict) |
| self.list_image_path.extend([images_folder] * len(cur_data_dict)) |
| else: |
| data_args.dataset_paths = [data_path] |
| rank0_print(f"Loading {data_path}") |
| with open(data_path) as file: |
| cur_data_dict = json.load(file) |
| rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}") |
| self.list_data_dict.extend(cur_data_dict) |
| self.list_image_path.extend([""] * len(cur_data_dict)) |
|
|
| rank0_print(f"Loaded {len(self.list_data_dict)} samples from {data_path}") |
| rank0_print("Formatting inputs...Skip in lazy mode") |
| self.tokenizer = tokenizer |
| self.data_args = data_args |
|
|
| def __len__(self): |
| return len(self.list_data_dict) |
|
|
| @property |
| def lengths(self): |
| length_list = [] |
| for sample in self.list_data_dict: |
            # Estimate ~1200 visual tokens per image for length-based sampling.
            img_tokens = (
                1200 * len(sample["image"]) if isinstance(sample.get("image"), list)
                else 1200 if "image" in sample
                else 0
            )
| length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens) |
| return length_list |
|
|
| @property |
| def modality_lengths(self): |
| length_list = [] |
| for sample in self.list_data_dict: |
| cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"]) |
| assert cur_len > 0, f"Conversation length is 0 for {sample}" |
|
|
            # Estimate ~1200 visual tokens per image for length-based sampling.
            img_tokens = (
                1200 * len(sample["image"]) if isinstance(sample.get("image"), list)
                else 1200 if "image" in sample
                else 0
            )
|
|
| if "image" in sample or "video" in sample or self.data_args.early_mix_text: |
| length_list.append(cur_len + img_tokens) |
| else: |
| length_list.append(-cur_len) |
| return length_list |
|
|
    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        try:
            sample = self._get_item(i)
            if sample is None:
                # Sample was filtered out (e.g. too long); retry a random one.
                new_index = random.randint(0, len(self.list_data_dict) - 1)
                return self.__getitem__(new_index)
        except Exception as e:
            print(f"Failed to fetch sample {i}. Exception:", e)
            new_index = random.randint(0, len(self.list_data_dict) - 1)
            return self.__getitem__(new_index)
        return sample
|
|
| def _get_item(self, i) -> Dict[str, torch.Tensor]: |
| sources = self.list_data_dict[i] |
| image_path = os.path.join(self.data_args.image_folder, self.list_image_path[i]) |
|
|
| if "image" in sources: |
| image_file = self.list_data_dict[i]["image"] |
| if type(image_file) is list: |
| image_list = [os.path.join(image_path, image_file) for image_file in image_file] |
| else: |
| image_list = [os.path.join(image_path, image_file)] |
|
|
| sources = copy.deepcopy(sources["conversations"]) |
| elif "video" in sources: |
| raise NotImplementedError("Video is not supported for Qwen2VL") |
        else:
            # Text-only sample: no images to attach.
            image_list = []
            sources = copy.deepcopy(sources["conversations"])
|
|
| item_id = self.list_data_dict[i].get("id", i) |
|
|
| data_dict = self.preprocess_qwen2vl(sources, self.tokenizer, self.processor, image_list, id=item_id) |
| if isinstance(i, int): |
| data_dict = { |
| "input_ids": data_dict["input_ids"][0], |
| "labels": data_dict["labels"][0], |
| "coordinates": data_dict["coordinates"][0], |
| "visual_token_indices_of_coordinates": data_dict["visual_token_indices_of_coordinates"][0], |
| "pixel_values": data_dict["pixel_values"], |
| "image_grid_thw": data_dict["image_grid_thw"], |
| "multi_patch_labels": data_dict["multi_patch_labels"][0], |
| } |
|
|
| data_dict["id"] = item_id |
|
|
| |
        # Number of visual tokens after spatial merging: t * h * w / merge_size^2.
        n_image_tokens = (
            data_dict["image_grid_thw"][0][0] *
            data_dict["image_grid_thw"][0][1] *
            data_dict["image_grid_thw"][0][2] /
            self.processor.image_processor.merge_size /
            self.processor.image_processor.merge_size
        )
| if (len(data_dict["input_ids"]) + n_image_tokens) > self.tokenizer.model_max_length: |
| rank0_print(f"=== Removed data_dict {i} because it is longer than the model_max_length: {len(data_dict['input_ids'])} + {n_image_tokens} > {self.tokenizer.model_max_length}") |
| return None |
|
|
| return data_dict |
|
|
| def preprocess_qwen2vl( |
| self, |
| source, |
| tokenizer: transformers.PreTrainedTokenizer, |
| processor: transformers.ProcessorMixin, |
| image: list, |
| system_message: str = grounding_system_message, |
| agent_mode: bool = True, |
| chat_template: str = chat_template, |
| assistant_template: str = assistant_template, |
| id: int = None, |
| ) -> Dict: |
| roles = {"human": "user", "gpt": "assistant", "system": "system"} |
| assistant_template = assistant_template if agent_mode else chat_template |
| processor.tokenizer = tokenizer |
| assert tokenizer.additional_special_tokens == ADDITIONAL_SPECIAL_TOKENS |
|
|
| |
| pixel_values, image_grid_thw = None, None |
|
|
        input_id, target = [], []
        coordinates = []
        visual_token_indices_of_coordinates = []
        multi_patch_labels = []

        image_list = []
        image_inputs = []  # images from the most recent user turn; grounding turns require this to be non-empty
        image_index = 0
|
|
| |
| if roles[source[0]["from"]] == "system": |
| system_message = source[0]["value"] |
| source = source[1:self.data_args.max_conv_turns] |
| |
| system_input_id = tokenizer.apply_chat_template( |
| conversation=[{"role": "system", "content": [{"type": "text", "text": system_message}]}], |
| chat_template=chat_template, |
| ) |
| input_id += system_input_id |
| target += [IGNORE_INDEX] * len(system_input_id) |
|
|
| |
        for conv in source:
            # Support both {"role", "content"} and {"from", "value"} schemas.
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]
            role = roles.get(role, role)
|
|
| |
            # Expand each DEFAULT_IMAGE_TOKEN placeholder into an image entry.
            image_count = content.count(DEFAULT_IMAGE_TOKEN)
            if image_count > 0:
                assert role == "user", "Images are only supported for user messages"
| |
| image_placeholders = [] |
| for _ in range(image_count): |
| image_placeholders.append({ |
| "type": "image", |
| "image": image[image_index], |
| "min_pixels": self.processor.image_processor.min_pixels, |
| "max_pixels": self.processor.image_processor.max_pixels, |
| }) |
| image_index += 1 |
|
|
| content = content.replace(DEFAULT_IMAGE_TOKEN, "") |
| conv = {"role": role, "content": image_placeholders + [{"type": "text", "text": content}]} |
|
|
| image_inputs, _ = process_vision_info([conv]) |
| image_list.extend(image_inputs) |
| |
| templated_conv = tokenizer.apply_chat_template( |
| conversation=[conv], chat_template=chat_template, tokenize=False |
| ) |
| inputs = processor(text=[templated_conv], images=image_inputs, return_tensors="pt") |
|
|
| if pixel_values is None and image_grid_thw is None: |
| pixel_values = inputs["pixel_values"] |
| image_grid_thw = inputs["image_grid_thw"] |
| else: |
| pixel_values = torch.concat([pixel_values, inputs["pixel_values"]], dim=0) |
| image_grid_thw = torch.concat([image_grid_thw, inputs["image_grid_thw"]], dim=0) |
| else: |
| if role in ["user", "system"]: |
| conv = {"role": role, "content": [{"type": "text", "text": content}]} |
| else: |
| conv = { |
| "role": role, |
| "content": [{"type": "text", "text": content}], |
| "recipient": conv.get("recipient", "os"), |
| "end_turn": conv.get("end_turn", True), |
| "bbox_gt": conv.get("bbox_gt", None), |
| } |
| if conv["recipient"] == "os": |
| if len(image_inputs) == 0: |
| raise ValueError("No image found for visual grounding") |
| |
| text, coord = reformat_coordinates(conv["content"][0]["text"]) |
| conv["content"][0]["text"] = text |
| |
|
|
| |
| coordinates.extend(coord) |
| for (point_x, point_y) in coord: |
| visual_token_index = get_token_index( |
| processor.image_processor, |
| image_list, |
| point_x, |
| point_y |
| ) |
| |
| |
| |
| |
| |
| |
| |
| visual_token_indices_of_coordinates.append(visual_token_index) |
|
|
| if conv["bbox_gt"] is not None: |
| patch_mask = get_multi_patch_labels( |
| processor.image_processor, |
| image_list, |
| conv["bbox_gt"] |
| ) |
| multi_patch_labels.append(patch_mask) |
|
|
| templated_conv = tokenizer.apply_chat_template( |
| conversation=[conv], |
| chat_template=assistant_template, |
| tokenize=False, |
| ) |
| inputs = processor(text=[templated_conv], return_tensors="pt") |
|
|
| encode_id = inputs.input_ids[0].tolist() |
|
|
| input_id += encode_id |
| if role in ["user", "system"]: |
| target += [IGNORE_INDEX] * len(encode_id) |
| else: |
| target += encode_id |
|
|
        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"

        # Exclude the pointer-end token from the language-modeling loss.
        target = [IGNORE_INDEX if token == self.pointer_end_token_id else token for token in target]
|
|
| input_ids = torch.tensor([input_id], dtype=torch.long) |
| targets = torch.tensor([target], dtype=torch.long) |
        # Add a batch dimension of 1; use [None] when a field is absent.
        visual_token_indices_of_coordinates = torch.tensor([visual_token_indices_of_coordinates], dtype=torch.long) if len(visual_token_indices_of_coordinates) > 0 else [None]
        coordinates = [coordinates] if len(coordinates) > 0 else [None]
|
|
| |
| if len(multi_patch_labels) > 0: |
| multi_patch_labels = [torch.stack(multi_patch_labels)] |
| else: |
| multi_patch_labels = [None] |
|
|
| data_dict = { |
| "input_ids": input_ids, |
| "labels": targets, |
| } |
|
|
| if pixel_values is not None: |
| data_dict["pixel_values"] = pixel_values |
| data_dict["image_grid_thw"] = image_grid_thw |
| |
| |
| |
| data_dict["coordinates"] = coordinates |
| data_dict["visual_token_indices_of_coordinates"] = visual_token_indices_of_coordinates |
| data_dict["multi_patch_labels"] = multi_patch_labels |
| |
| return data_dict |
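

# Minimal usage sketch (illustrative only: the model name, data path, and
# data_args object are assumptions; data_args must supply the fields used above,
# such as image_folder and max_conv_turns, and the tokenizer must already carry
# ADDITIONAL_SPECIAL_TOKENS):
#
#   from transformers import AutoProcessor, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#   processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#   dataset = LazySupervisedDataset(tokenizer, processor, "data/train.yaml", data_args)
#   sample = dataset[0]
#   # -> dict with input_ids, labels, pixel_values, image_grid_thw, coordinates,
#   #    visual_token_indices_of_coordinates, multi_patch_labels, and id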
|
|