| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Optional, Union |
| | from types import SimpleNamespace |
| | from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import Qwen2VLImageProcessorFast |
| | from functools import partial, lru_cache |
| | from transformers.image_processing_utils import BatchFeature |
| | from transformers.image_utils import ( |
| | ChannelDimension, |
| | SizeDict, |
| | make_flat_list_of_images, |
| | valid_images, |
| | pil_torch_interpolation_mapping, |
| | ) |
| | from torchvision.transforms.v2 import functional as F |
| | import torch |
| | from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize |
| | |
| | |
| | |
| | |
| |
|
| |
|
def rescale(image, scale):
    """Return ``image`` multiplied element-wise by ``scale``.

    Args:
        image: Any object supporting ``*`` with ``scale`` (e.g. a ``torch.Tensor``).
        scale: Multiplicative factor.

    Returns:
        The product ``image * scale``.
    """
    scaled = image * scale
    return scaled
| |
|
| |
|
def normalize(image, mean, std):
    """Channel-wise normalize ``image`` as ``(image - mean) / std``.

    Thin wrapper around ``torchvision.transforms.v2.functional.normalize``.
    ``mean``/``std`` may be sequences of floats or tensors.
    """
    normalized = F.normalize(image, mean=mean, std=std)
    return normalized
| |
|
| |
|
@lru_cache(maxsize=10)
def _fuse_mean_std_and_rescale_factor(
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, list[float]]] = None,
    image_std: Optional[Union[float, list[float]]] = None,
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    device: Optional["torch.device"] = None,
) -> tuple:
    """Fold the rescale factor into the normalization statistics.

    With rescale factor ``r``, ``(x * r - m) / s == (x - m / r) / (s / r)``.
    When both rescaling and normalization are requested, the mean and std are
    therefore divided by ``r`` and rescaling is switched off. This lets the
    caller run a single normalize pass.

    Results are memoized with ``lru_cache``:
    - all arguments must be hashable (e.g. tuples rather than lists);
    - the returned tensors are shared between calls with equal arguments, so
      callers should not modify them in place.

    Returns:
        ``(image_mean, image_std, do_rescale)``. When fusion happens, the first
        two are tensors on ``device`` and the last is ``False``. Otherwise the
        inputs are returned unchanged.
    """
    if not (do_rescale and do_normalize):
        # Nothing to fuse: hand the inputs back untouched.
        return image_mean, image_std, do_rescale
    inverse_factor = 1.0 / rescale_factor
    fused_mean = torch.tensor(image_mean, device=device) * inverse_factor
    fused_std = torch.tensor(image_std, device=device) * inverse_factor
    # Rescaling now lives inside mean/std, so the caller must not rescale again.
    return fused_mean, fused_std, False
| |
|
| |
|
def rescale_and_normalize(
    images: "torch.Tensor",
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    image_mean: Union[float, list[float]],
    image_std: Union[float, list[float]],
) -> "torch.Tensor":
    """Rescale and/or normalize ``images``, then cast them to the processor dtype.

    Behaviour by flags:
    - ``do_rescale`` and ``do_normalize``: the rescale factor is folded into
      mean/std (see ``_fuse_mean_std_and_rescale_factor``) and a single
      float32 normalize pass is run.
    - Only ``do_normalize``: images are converted to float32 and normalized.
    - Only ``do_rescale``: images are multiplied by ``rescale_factor``.

    The result is always cast to ``OpenPanguVLImageProcessorFast.dtype``.

    Args:
        images: Image tensor. Its ``device`` is used for the fused mean/std tensors.
        do_rescale: Whether to multiply by ``rescale_factor``.
        rescale_factor: Multiplicative rescale factor.
        do_normalize: Whether to normalize with ``image_mean``/``image_std``.
        image_mean: Per-channel mean (sequence) or a scalar.
        image_std: Per-channel std (sequence) or a scalar.

    Returns:
        The transformed tensor in ``OpenPanguVLImageProcessorFast.dtype``.
    """
    # The fusion helper is memoized with lru_cache, which hashes its arguments.
    # Lists are allowed by the annotations but unhashable, so freeze them first.
    if isinstance(image_mean, list):
        image_mean = tuple(image_mean)
    if isinstance(image_std, list):
        image_std = tuple(image_std)

    image_mean, image_std, do_rescale = _fuse_mean_std_and_rescale_factor(
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        device=images.device,
    )

    if do_normalize:
        # If rescaling was requested too, it is already folded into mean/std here.
        images = normalize(images.to(dtype=torch.float32), image_mean, image_std)
    elif do_rescale:
        images = rescale(images, rescale_factor)
    return images.to(OpenPanguVLImageProcessorFast.dtype)
| |
|
| | |
| | from collections import defaultdict |
def _group_images_by_shape(nested_images, is_nested: bool = False):
    """Flatten at most one level of nesting and bucket images by ``image.shape[1:]``.

    Args:
        nested_images: A list of images, or a list of lists of images when ``is_nested``.
        is_nested: Whether ``nested_images`` has an outer list level.

    Returns:
        ``(buckets, positions)``:
        - ``buckets`` maps each shape key to the list of images with that shape,
          in encounter order.
        - ``positions`` maps each image's original position to
          ``(shape key, index inside that bucket)``. The original position is
          ``j`` for flat input and ``(i, j)`` for nested input.
    """
    buckets = defaultdict(list)
    positions = {}
    outer = nested_images if is_nested else [nested_images]
    for outer_idx, inner in enumerate(outer):
        for inner_idx, img in enumerate(inner):
            spatial = img.shape[1:]
            bucket = buckets[spatial]
            bucket.append(img)
            position = (outer_idx, inner_idx) if is_nested else inner_idx
            positions[position] = (spatial, len(bucket) - 1)
    return buckets, positions
| |
|
| |
|
def _reconstruct_nested_structure(indices, processed_images):
    """Rebuild a one-level nested list from grouped results.

    Args:
        indices: Maps ``(outer, inner)`` positions to ``(group key, index in group)``.
        processed_images: Maps each group key to an indexable batch of results.

    Returns:
        A list of lists where ``result[i][j]`` is the processed item for position
        ``(i, j)``. Outer or inner positions missing from ``indices`` are left as
        ``None``. An empty ``indices`` yields ``[]``; previously ``max()`` of an
        empty sequence raised ``ValueError``.
    """
    if not indices:
        return []

    max_outer_idx = max(outer for outer, _ in indices)
    result = [None] * (max_outer_idx + 1)

    # Collect, per outer position, the inner positions that are present.
    inner_positions = defaultdict(list)
    for outer, inner in indices:
        inner_positions[outer].append(inner)

    for outer, inners in inner_positions.items():
        inner_list = [None] * (max(inners) + 1)
        for inner in inners:
            group_key, offset = indices[(outer, inner)]
            inner_list[inner] = processed_images[group_key][offset]
        result[outer] = inner_list

    return result
| |
|
| |
|
def group_images_by_shape(
    images: Union[list["torch.Tensor"], "torch.Tensor"],
    disable_grouping: bool,
    is_nested: bool = False,
) -> tuple[
    dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]
]:
    """Batch images with identical shapes so they can be processed together.

    Args:
        images: Images (or a list of image lists when ``is_nested``).
        disable_grouping: ``True`` puts every image in its own batch of size 1.
            ``False`` stacks images by shape. ``None`` disables grouping when the
            first image lives on the CPU and enables it otherwise.
        is_nested: Whether ``images`` has one extra outer list level.

    Returns:
        ``(batches, index)``:
        - ``batches`` maps a group key to a stacked tensor whose leading dim is
          the batch.
        - ``index`` maps each original position to ``(group key, index in
          batch)``; pass it to ``reorder_images`` to restore input order.
    """
    if disable_grouping is None:
        # A torch.device does not compare equal to a plain string, so the old
        # ``device == "cpu"`` check never matched. Compare the device type instead.
        # An empty input falls through to the (trivially empty) grouped path.
        flat = [image for sublist in images for image in sublist] if is_nested else images
        disable_grouping = len(flat) > 0 and flat[0].device.type == "cpu"

    if disable_grouping:
        # One batch per image; the position itself serves as the group key.
        if is_nested:
            batches = {
                (i, j): image.unsqueeze(0)
                for i, sublist in enumerate(images)
                for j, image in enumerate(sublist)
            }
        else:
            batches = {i: image.unsqueeze(0) for i, image in enumerate(images)}
        return batches, {key: (key, 0) for key in batches}

    grouped_images, grouped_images_index = _group_images_by_shape(images, is_nested)
    # Stack same-shaped images along a new leading batch dimension.
    batches = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}
    return batches, grouped_images_index
| |
|
| |
|
def reorder_images(
    processed_images: dict[tuple[int, int], "torch.Tensor"],
    grouped_images_index: dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]],
    is_nested: bool = False,
) -> Union[list["torch.Tensor"], "torch.Tensor"]:
    """Put per-group results back into the original image order.

    Args:
        processed_images: Maps each group key to an indexable batch of results.
        grouped_images_index: Maps each original position to ``(group key, index
            in batch)``. Flat positions must be ``0 .. len - 1``.
        is_nested: Whether positions are ``(outer, inner)`` pairs.

    Returns:
        A flat list in original order, or a list of lists when ``is_nested``.
    """
    if is_nested:
        return _reconstruct_nested_structure(grouped_images_index, processed_images)

    ordered = []
    for position in range(len(grouped_images_index)):
        group_key, offset = grouped_images_index[position]
        ordered.append(processed_images[group_key][offset])
    return ordered
| |
|
| |
|
class OpenPanguVLImageProcessorFast(Qwen2VLImageProcessorFast):
    """Fast image processor for OpenPangu-VL, derived from ``Qwen2VLImageProcessorFast``.

    Differences visible in this class:

    * ``temporal_patch_size`` is forced to the class value (1) on every ``preprocess`` call.
    * Images with a side ``<= min_pxl`` pixels are first upscaled so that their
      shorter side becomes ``min_edge`` pixels.
    * ``_preprocess`` resizes and patchifies, but does not apply rescaling or
      normalization.
    """

    temporal_patch_size = 1  # forced into every preprocess call
    min_pxl = 28  # images with a side <= this many pixels are upscaled
    min_edge = 56  # target length of the shorter side after that upscale
    dtype = torch.bfloat16  # cast target of the module-level rescale_and_normalize helper

    @staticmethod
    def _small_image_target_size(width, height):
        """Return ``(new_width, new_height)`` for an undersized image, else ``None``.

        An image is undersized when either side is ``<= min_pxl``. Its shorter side
        is scaled to ``min_edge`` and the other side by the same ratio, truncated to int.
        """
        min_pxl = OpenPanguVLImageProcessorFast.min_pxl
        min_edge = OpenPanguVLImageProcessorFast.min_edge
        if width > min_pxl and height > min_pxl:
            return None
        if width >= height:
            aspect_ratio = min_edge * 1.0 / height
            return int(aspect_ratio * width), min_edge
        aspect_ratio = min_edge * 1.0 / width
        return min_edge, int(aspect_ratio * height)

    def _prepare_input_images(
        self,
        images,
        do_convert_rgb,
        input_data_format,
        device,
    ) -> list["torch.Tensor"]:
        """Convert input images to tensors, upscaling undersized ones.

        PIL images, whose ``size`` is a ``(width, height)`` tuple, are resized with
        ``PIL.Image.resize`` before conversion.

        Other inputs (tensors, arrays) expose no such tuple: ``Tensor.size`` is a
        method and ``ndarray.size`` is an int. They are converted first and, if
        undersized, resized on their last two dims with torchvision's ``resize``
        (default interpolation). The last two dims are treated as (H, W), the same
        convention ``_preprocess`` relies on.
        """
        images = self._prepare_images_structure(images)
        process_image_fn = partial(
            self._process_image,
            do_convert_rgb=do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
        )

        processed_images = []
        for image in images:
            pil_size = getattr(image, "size", None)
            if isinstance(pil_size, tuple):
                # PIL path: size is (width, height); resize before conversion.
                target = self._small_image_target_size(*pil_size)
                if target is not None:
                    image = image.resize(target)
                processed_images.append(process_image_fn(image))
                continue

            tensor = process_image_fn(image)
            height, width = tensor.shape[-2:]
            target = self._small_image_target_size(width, height)
            if target is not None:
                new_width, new_height = target
                tensor = F.resize(tensor, [new_height, new_width])
            processed_images.append(tensor)
        return processed_images

    def preprocess(
        self,
        images=None,
        videos=None,
        do_resize=None,
        size=None,
        resample=None,
        do_rescale=None,
        rescale_factor=None,
        do_normalize=None,
        image_mean=None,
        image_std=None,
        min_pixels=None,
        max_pixels=None,
        patch_size=None,
        temporal_patch_size=None,
        merge_size=None,
        do_convert_rgb=None,
        return_tensors=None,
        data_format=ChannelDimension.FIRST,
        input_data_format=None,
        device=None,
        disable_grouping=False,
        **kwargs,
    ):
        """Turn ``images`` into flattened patches plus per-image patch-grid sizes.

        Argument handling:
        - Each argument left as ``None`` falls back to the processor attribute of
          the same name.
        - ``temporal_patch_size`` is always replaced by the class value.
        - ``videos``, ``data_format`` and extra ``kwargs`` are accepted but not
          used by this method.

        Returns:
            ``BatchFeature`` with ``pixel_values`` and ``image_grid_thw``.
        """
        # Read the class attribute explicitly. Presumably the parent's __init__ can
        # set an instance-level temporal_patch_size; confirm before changing.
        temporal_patch_size = OpenPanguVLImageProcessorFast.temporal_patch_size
        params = self._resolve_preprocess_params(
            do_resize=do_resize,
            size=size,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            merge_size=merge_size,
            do_convert_rgb=do_convert_rgb,
        )

        return self._process_images(
            images,
            params,
            input_data_format,
            device,
            disable_grouping,
            return_tensors,
        )

    @staticmethod
    def _as_hashable_stats(values):
        """Normalize a mean/std setting.

        Returns ``None`` for ``None`` or an empty sequence, scalars unchanged, and
        any other sequence as a tuple (hashable, as required by the cached fusion helper).
        """
        if values is None:
            return None
        if isinstance(values, (int, float)):
            return values
        return tuple(values) if len(values) else None

    def _resolve_preprocess_params(self, **kwargs):
        """Merge per-call arguments with the processor defaults into a namespace.

        Steps:
        - Every keyword whose value is ``None`` is replaced by ``getattr(self, key)``.
        - ``size`` is resolved as follows: an explicit ``size`` wins; otherwise
          per-call ``min_pixels``/``max_pixels`` override the configured
          ``shortest_edge``/``longest_edge``. Previously they were silently ignored
          whenever ``self.size`` was set. With no size configured at all, it is
          built from the pixel bounds.
        - ``size`` is converted to a ``SizeDict``.
        - ``image_mean``/``image_std`` are normalized by ``_as_hashable_stats``.
        """
        params = SimpleNamespace(
            **{key: getattr(self, key) if value is None else value for key, value in kwargs.items()}
        )
        caller_min = kwargs.get("min_pixels")
        caller_max = kwargs.get("max_pixels")
        if params.size is None:
            params.size = {"shortest_edge": params.min_pixels, "longest_edge": params.max_pixels}
        elif kwargs.get("size") is None and (caller_min is not None or caller_max is not None):
            # Copy so the processor's configured size is not mutated.
            params.size = dict(params.size)
            if caller_min is not None:
                params.size["shortest_edge"] = caller_min
            if caller_max is not None:
                params.size["longest_edge"] = caller_max
        params.size = SizeDict(**params.size)
        params.image_mean = self._as_hashable_stats(params.image_mean)
        params.image_std = self._as_hashable_stats(params.image_std)
        return params

    def _process_images(self, images, params, input_data_format, device, disable_grouping, return_tensors):
        """Validate and convert ``images``, then run ``_preprocess`` with ``params``.

        Raises:
            ValueError: if ``valid_images`` rejects the flattened input.
        """
        images = make_flat_list_of_images(images)
        if not valid_images(images):
            raise ValueError("Invalid image type.")

        images = self._prepare_input_images(
            images=images,
            do_convert_rgb=params.do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
        )

        return self._preprocess(
            images=images,
            do_resize=params.do_resize,
            size=params.size,
            interpolation=pil_torch_interpolation_mapping.get(params.resample, params.resample),
            do_rescale=params.do_rescale,
            rescale_factor=params.rescale_factor,
            do_normalize=params.do_normalize,
            image_mean=params.image_mean,
            image_std=params.image_std,
            patch_size=params.patch_size,
            temporal_patch_size=params.temporal_patch_size,
            merge_size=params.merge_size,
            do_convert_rgb=params.do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
            disable_grouping=disable_grouping,
            return_tensors=return_tensors,
        )

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        disable_grouping: Optional[bool],
        return_tensors,
        **kwargs,
    ):
        """Resize images and cut them into flattened patches.

        Processing runs in two passes:
        - Images are batched by shape and, if ``do_resize``, resized to the size
          returned by ``smart_resize``. That call uses factor
          ``patch_size * merge_size`` and the pixel bounds from
          ``size["shortest_edge"]``/``size["longest_edge"]``.
        - The resized images are re-batched by their new shape and patchified.

        NOTE(review): ``do_rescale``, ``rescale_factor``, ``do_normalize``,
        ``image_mean`` and ``image_std`` are accepted but NOT applied here, and no
        dtype conversion is done. Presumably normalization happens downstream
        (e.g. via the module-level ``rescale_and_normalize``); confirm.

        Returns:
            ``BatchFeature`` with two entries:
            - ``pixel_values``, shaped ``(total_patches, C * temporal_patch_size * patch_size ** 2)``;
            - ``image_grid_thw``, with one ``[grid_t, grid_h, grid_w]`` row per
              image in input order.
        """
        # Pass 1: batch same-shaped images so each batch is resized in one call.
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                height, width = stacked_images.shape[-2:]
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=patch_size * merge_size,
                    min_pixels=size["shortest_edge"],
                    max_pixels=size["longest_edge"],
                )
                stacked_images = self.resize(
                    image=stacked_images,
                    size=SizeDict(height=resized_height, width=resized_width),
                    interpolation=interpolation,
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Pass 2: re-batch by post-resize shape (sizes may have changed, or resizing
        # may be disabled) and patchify each batch.
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        processed_grids = {}
        for shape, stacked_images in grouped_images.items():
            resized_height, resized_width = stacked_images.shape[-2:]
            patches = stacked_images
            if patches.ndim == 4:
                # (B, C, H, W) -> (B, 1, C, H, W): add a time axis for still images.
                patches = patches.unsqueeze(1)
            remainder = patches.shape[1] % temporal_patch_size
            if remainder != 0:
                # Pad the time axis up to a multiple of temporal_patch_size by
                # repeating the last frame.
                repeats = patches[:, -1:].repeat(1, temporal_patch_size - remainder, 1, 1, 1)
                patches = torch.cat([patches, repeats], dim=1)
            batch_size, grid_t, channel = patches.shape[:3]
            grid_t = grid_t // temporal_patch_size
            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

            patches = patches.view(
                batch_size,
                grid_t,
                temporal_patch_size,
                channel,
                grid_h // merge_size,
                merge_size,
                patch_size,
                grid_w // merge_size,
                merge_size,
                patch_size,
            )
            # -> (B, grid_t, grid_h/m, grid_w/m, m, m, C, tp, ps, ps). Each
            # merge_size x merge_size block of neighbouring patches becomes
            # contiguous, and every patch is flattened as (C, tp, ps, ps).
            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
            flatten_patches = patches.reshape(
                batch_size,
                grid_t * grid_h * grid_w,
                channel * temporal_patch_size * patch_size * patch_size,
            )

            processed_images_grouped[shape] = flatten_patches
            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_grids = reorder_images(processed_grids, grouped_images_index)
        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.tensor(processed_grids)

        return BatchFeature(
            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw},
            tensor_type=return_tensors,
        )